Feat: add deepseek integration [WIP] (only completion works)

This commit is contained in:
Grail Finder
2025-02-28 16:16:11 +03:00
parent 49409f5d94
commit fd1ac24d75
5 changed files with 329 additions and 23 deletions

93
bot.go
View File

@@ -2,6 +2,8 @@ package main
import ( import (
"bufio" "bufio"
"bytes"
"context"
"elefant/config" "elefant/config"
"elefant/models" "elefant/models"
"elefant/rag" "elefant/rag"
@@ -10,6 +12,7 @@ import (
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"net"
"net/http" "net/http"
"os" "os"
"path" "path"
@@ -20,7 +23,30 @@ import (
"github.com/rivo/tview" "github.com/rivo/tview"
) )
var httpClient = http.Client{} var httpClient = &http.Client{}
func createClient(connectTimeout time.Duration) *http.Client {
// Custom transport with connection timeout
transport := &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
// Create a dialer with connection timeout
dialer := &net.Dialer{
Timeout: connectTimeout,
KeepAlive: 30 * time.Second, // Optional
}
return dialer.DialContext(ctx, network, addr)
},
// Other transport settings (optional)
TLSHandshakeTimeout: connectTimeout,
ResponseHeaderTimeout: connectTimeout,
}
// Client with no overall timeout (or set to streaming-safe duration)
return &http.Client{
Transport: transport,
Timeout: 0, // No overall timeout (for streaming)
}
}
var ( var (
cfg *config.Config cfg *config.Config
@@ -36,7 +62,6 @@ var (
defaultStarterBytes = []byte{} defaultStarterBytes = []byte{}
interruptResp = false interruptResp = false
ragger *rag.RAG ragger *rag.RAG
currentModel = "none"
chunkParser ChunkParser chunkParser ChunkParser
defaultLCPProps = map[string]float32{ defaultLCPProps = map[string]float32{
"temperature": 0.8, "temperature": 0.8,
@@ -47,6 +72,7 @@ var (
) )
func fetchModelName() *models.LLMModels { func fetchModelName() *models.LLMModels {
// TODO: to config
api := "http://localhost:8080/v1/models" api := "http://localhost:8080/v1/models"
//nolint //nolint
resp, err := httpClient.Get(api) resp, err := httpClient.Get(api)
@@ -61,16 +87,44 @@ func fetchModelName() *models.LLMModels {
return nil return nil
} }
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
currentModel = "disconnected" chatBody.Model = "disconnected"
return nil return nil
} }
currentModel = path.Base(llmModel.Data[0].ID) chatBody.Model = path.Base(llmModel.Data[0].ID)
return &llmModel return &llmModel
} }
// fetchDSBalance queries DeepSeek's /user/balance endpoint using the
// configured API token. It returns nil on any failure (request creation,
// network error, or body decode); failures are logged, never surfaced.
func fetchDSBalance() *models.DSBalance {
	url := "https://api.deepseek.com/user/balance"
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		logger.Warn("failed to create request", "error", err)
		return nil
	}
	req.Header.Add("Accept", "application/json")
	req.Header.Add("Authorization", "Bearer "+cfg.DeepSeekToken)
	res, err := httpClient.Do(req)
	if err != nil {
		logger.Warn("failed to make request", "error", err)
		return nil
	}
	defer res.Body.Close()
	resp := models.DSBalance{}
	if err := json.NewDecoder(res.Body).Decode(&resp); err != nil {
		// BUG FIX: this error was previously swallowed silently, hiding
		// auth failures and schema changes; log it like the other paths.
		logger.Warn("failed to decode deepseek balance", "error", err)
		return nil
	}
	return &resp
}
func sendMsgToLLM(body io.Reader) { func sendMsgToLLM(body io.Reader) {
choseChunkParser()
req, err := http.NewRequest("POST", cfg.CurrentAPI, body)
req.Header.Add("Accept", "application/json")
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", "Bearer "+cfg.DeepSeekToken)
// nolint // nolint
resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", body) // resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", body)
if err != nil { if err != nil {
logger.Error("llamacpp api", "error", err) logger.Error("llamacpp api", "error", err)
if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil { if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil {
@@ -79,6 +133,16 @@ func sendMsgToLLM(body io.Reader) {
streamDone <- true streamDone <- true
return return
} }
resp, err := httpClient.Do(req)
if err != nil {
bodyBytes, _ := io.ReadAll(body)
logger.Error("llamacpp api", "error", err, "body", string(bodyBytes))
if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil {
logger.Error("failed to notify", "error", err)
}
streamDone <- true
return
}
defer resp.Body.Close() defer resp.Body.Close()
reader := bufio.NewReader(resp.Body) reader := bufio.NewReader(resp.Body)
counter := uint32(0) counter := uint32(0)
@@ -113,6 +177,10 @@ func sendMsgToLLM(body io.Reader) {
// starts with -> data: // starts with -> data:
line = line[6:] line = line[6:]
logger.Debug("debugging resp", "line", string(line)) logger.Debug("debugging resp", "line", string(line))
if bytes.Equal(line, []byte("[DONE]\n")) {
streamDone <- true
break
}
content, stop, err = chunkParser.ParseChunk(line) content, stop, err = chunkParser.ParseChunk(line)
if err != nil { if err != nil {
logger.Error("error parsing response body", "error", err, "line", string(line), "url", cfg.CurrentAPI) logger.Error("error parsing response body", "error", err, "line", string(line), "url", cfg.CurrentAPI)
@@ -185,7 +253,17 @@ func roleToIcon(role string) string {
func chatRound(userMsg, role string, tv *tview.TextView, regen, resume bool) { func chatRound(userMsg, role string, tv *tview.TextView, regen, resume bool) {
botRespMode = true botRespMode = true
// reader := formMsg(chatBody, userMsg, role) defer func() { botRespMode = false }()
// check that there is a model set to use if is not local
if cfg.CurrentAPI == cfg.DeepSeekChatAPI || cfg.CurrentAPI == cfg.DeepSeekCompletionAPI {
if chatBody.Model != "deepseek-chat" && chatBody.Model != "deepseek-reasoner" {
if err := notifyUser("bad request", "wrong deepseek model name"); err != nil {
logger.Warn("failed ot notify user", "error", err)
return
}
return
}
}
reader, err := chunkParser.FormMsg(userMsg, role, resume) reader, err := chunkParser.FormMsg(userMsg, role, resume)
if reader == nil || err != nil { if reader == nil || err != nil {
logger.Error("empty reader from msgs", "role", role, "error", err) logger.Error("empty reader from msgs", "role", role, "error", err)
@@ -369,7 +447,8 @@ func init() {
Stream: true, Stream: true,
Messages: lastChat, Messages: lastChat,
} }
initChunkParser() choseChunkParser()
httpClient = createClient(time.Second * 15)
// go runModelNameTicker(time.Second * 120) // go runModelNameTicker(time.Second * 120)
// tempLoad() // tempLoad()
} }

View File

@@ -10,6 +10,7 @@ type Config struct {
ChatAPI string `toml:"ChatAPI"` ChatAPI string `toml:"ChatAPI"`
CompletionAPI string `toml:"CompletionAPI"` CompletionAPI string `toml:"CompletionAPI"`
CurrentAPI string CurrentAPI string
CurrentProvider string
APIMap map[string]string APIMap map[string]string
// //
ShowSys bool `toml:"ShowSys"` ShowSys bool `toml:"ShowSys"`
@@ -30,6 +31,12 @@ type Config struct {
RAGWorkers uint32 `toml:"RAGWorkers"` RAGWorkers uint32 `toml:"RAGWorkers"`
RAGBatchSize int `toml:"RAGBatchSize"` RAGBatchSize int `toml:"RAGBatchSize"`
RAGWordLimit uint32 `toml:"RAGWordLimit"` RAGWordLimit uint32 `toml:"RAGWordLimit"`
// deepseek
DeepSeekChatAPI string `toml:"DeepSeekChatAPI"`
DeepSeekCompletionAPI string `toml:"DeepSeekCompletionAPI"`
DeepSeekToken string `toml:"DeepSeekToken"`
DeepSeekModel string `toml:"DeepSeekModel"`
ApiLinks []string
} }
func LoadConfigOrDefault(fn string) *Config { func LoadConfigOrDefault(fn string) *Config {
@@ -39,9 +46,11 @@ func LoadConfigOrDefault(fn string) *Config {
config := &Config{} config := &Config{}
_, err := toml.DecodeFile(fn, &config) _, err := toml.DecodeFile(fn, &config)
if err != nil { if err != nil {
fmt.Println("failed to read config from file, loading default") fmt.Println("failed to read config from file, loading default", "error", err)
config.ChatAPI = "http://localhost:8080/v1/chat/completions" config.ChatAPI = "http://localhost:8080/v1/chat/completions"
config.CompletionAPI = "http://localhost:8080/completion" config.CompletionAPI = "http://localhost:8080/completion"
config.DeepSeekCompletionAPI = "https://api.deepseek.com/beta/completions"
config.DeepSeekChatAPI = "https://api.deepseek.com/chat/completions"
config.RAGEnabled = false config.RAGEnabled = false
config.EmbedURL = "http://localhost:8080/v1/embiddings" config.EmbedURL = "http://localhost:8080/v1/embiddings"
config.ShowSys = true config.ShowSys = true
@@ -59,11 +68,15 @@ func LoadConfigOrDefault(fn string) *Config {
config.CurrentAPI = config.ChatAPI config.CurrentAPI = config.ChatAPI
config.APIMap = map[string]string{ config.APIMap = map[string]string{
config.ChatAPI: config.CompletionAPI, config.ChatAPI: config.CompletionAPI,
config.DeepSeekChatAPI: config.DeepSeekCompletionAPI,
} }
if config.CompletionAPI != "" { if config.CompletionAPI != "" {
config.CurrentAPI = config.CompletionAPI config.CurrentAPI = config.CompletionAPI
config.APIMap = map[string]string{ config.APIMap[config.CompletionAPI] = config.ChatAPI
config.CompletionAPI: config.ChatAPI, }
for _, el := range []string{config.ChatAPI, config.CompletionAPI, config.DeepSeekChatAPI, config.DeepSeekCompletionAPI} {
if el != "" {
config.ApiLinks = append(config.ApiLinks, el)
} }
} }
// if any value is empty fill with default // if any value is empty fill with default

90
llm.go
View File

@@ -13,20 +13,32 @@ type ChunkParser interface {
FormMsg(msg, role string, cont bool) (io.Reader, error) FormMsg(msg, role string, cont bool) (io.Reader, error)
} }
func initChunkParser() { func choseChunkParser() {
chunkParser = LlamaCPPeer{} chunkParser = LlamaCPPeer{}
if strings.Contains(cfg.CurrentAPI, "v1") { switch cfg.CurrentAPI {
logger.Debug("chosen /v1/chat parser") case "http://localhost:8080/completion":
chunkParser = LlamaCPPeer{}
case "http://localhost:8080/v1/chat/completions":
chunkParser = OpenAIer{} chunkParser = OpenAIer{}
return case "https://api.deepseek.com/beta/completions":
chunkParser = DeepSeeker{}
default:
chunkParser = LlamaCPPeer{}
} }
logger.Debug("chosen llamacpp /completion parser") // if strings.Contains(cfg.CurrentAPI, "chat") {
// logger.Debug("chosen chat parser")
// chunkParser = OpenAIer{}
// return
// }
// logger.Debug("chosen llamacpp /completion parser")
} }
type LlamaCPPeer struct { type LlamaCPPeer struct {
} }
type OpenAIer struct { type OpenAIer struct {
} }
type DeepSeeker struct {
}
func (lcp LlamaCPPeer) FormMsg(msg, role string, resume bool) (io.Reader, error) { func (lcp LlamaCPPeer) FormMsg(msg, role string, resume bool) (io.Reader, error) {
if msg != "" { // otherwise let the bot to continue if msg != "" { // otherwise let the bot to continue
@@ -62,7 +74,12 @@ func (lcp LlamaCPPeer) FormMsg(msg, role string, resume bool) (io.Reader, error)
} }
logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse, logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse,
"msg", msg, "resume", resume, "prompt", prompt) "msg", msg, "resume", resume, "prompt", prompt)
payload := models.NewLCPReq(prompt, cfg, defaultLCPProps) var payload any
payload = models.NewLCPReq(prompt, cfg, defaultLCPProps)
if strings.Contains(chatBody.Model, "deepseek") {
payload = models.NewDSCompletionReq(prompt, chatBody.Model,
defaultLCPProps["temp"], cfg)
}
data, err := json.Marshal(payload) data, err := json.Marshal(payload)
if err != nil { if err != nil {
logger.Error("failed to form a msg", "error", err) logger.Error("failed to form a msg", "error", err)
@@ -129,3 +146,64 @@ func (op OpenAIer) FormMsg(msg, role string, resume bool) (io.Reader, error) {
} }
return bytes.NewReader(data), nil return bytes.NewReader(data), nil
} }
// deepseek
// ParseChunk decodes one SSE data line from the DeepSeek /completions
// stream. It returns the text delta, whether the stream has finished
// (non-empty finish_reason), and any decode error.
func (ds DeepSeeker) ParseChunk(data []byte) (string, bool, error) {
	llmchunk := models.DSCompletionResp{}
	if err := json.Unmarshal(data, &llmchunk); err != nil {
		logger.Error("failed to decode", "error", err, "line", string(data))
		return "", false, err
	}
	// BUG FIX: guard against an empty choices array (keep-alive or error
	// payloads) which previously panicked on the unconditional index 0.
	if len(llmchunk.Choices) == 0 {
		return "", false, nil
	}
	if llmchunk.Choices[0].FinishReason != "" {
		if llmchunk.Choices[0].Text != "" {
			logger.Error("text inside of finish llmchunk", "chunk", llmchunk)
		}
		return llmchunk.Choices[0].Text, true, nil
	}
	return llmchunk.Choices[0].Text, false, nil
}
// FormMsg flattens the accumulated chat history (plus the new user message
// and optional RAG/tool context) into a single prompt string and marshals a
// DeepSeek /completions request body. When resume is true the prompt is not
// terminated with a fresh assistant turn, letting the model continue its
// previous answer.
func (ds DeepSeeker) FormMsg(msg, role string, resume bool) (io.Reader, error) {
	if msg != "" { // otherwise let the bot continue
		newMsg := models.RoleMsg{Role: role, Content: msg}
		chatBody.Messages = append(chatBody.Messages, newMsg)
		// augment the history with retrieval results when RAG is enabled
		if cfg.RAGEnabled {
			ragResp, err := chatRagUse(newMsg.Content)
			if err != nil {
				logger.Error("failed to form a rag msg", "error", err)
				return nil, err
			}
			ragMsg := models.RoleMsg{Role: cfg.ToolRole, Content: ragResp}
			chatBody.Messages = append(chatBody.Messages, ragMsg)
		}
	}
	if cfg.ToolUse && !resume {
		// inject the tool-use system message into the chat body
		chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg})
	}
	messages := make([]string, len(chatBody.Messages))
	for i, m := range chatBody.Messages {
		messages[i] = m.ToPrompt()
	}
	prompt := strings.Join(messages, "\n")
	if !resume {
		// open a fresh assistant turn for the model to fill
		prompt += "\n" + cfg.AssistantRole + ":\n"
	}
	if cfg.ThinkUse && !cfg.ToolUse {
		prompt += "<think>"
	}
	logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse,
		"msg", msg, "resume", resume, "prompt", prompt)
	// BUG FIX: defaultLCPProps is keyed by "temperature" (see its
	// declaration); the previous lookup of "temp" always returned the map
	// zero value, silently forcing temperature 0.0 on every request.
	// TODO(review): confirm no "temp" key exists in the non-visible part
	// of defaultLCPProps.
	payload := models.NewDSCompletionReq(prompt, chatBody.Model,
		defaultLCPProps["temperature"], cfg)
	data, err := json.Marshal(payload)
	if err != nil {
		logger.Error("failed to form a msg", "error", err)
		return nil, err
	}
	return bytes.NewReader(data), nil
}

View File

@@ -103,6 +103,126 @@ type ChatToolsBody struct {
ToolChoice string `json:"tool_choice"` ToolChoice string `json:"tool_choice"`
} }
// DSChatReq is the request body for DeepSeek's /chat/completions endpoint.
// Only the fields this client actually sends are active; the commented-out
// fields are part of the API surface but not used yet.
type DSChatReq struct {
	Messages         []RoleMsg `json:"messages"`
	Model            string    `json:"model"`
	Stream           bool      `json:"stream"`
	FrequencyPenalty int       `json:"frequency_penalty"`
	MaxTokens        int       `json:"max_tokens"`
	PresencePenalty  int       `json:"presence_penalty"`
	Temperature      float32   `json:"temperature"`
	TopP             float32   `json:"top_p"`
	// ResponseFormat struct {
	// 	Type string `json:"type"`
	// } `json:"response_format"`
	// Stop any `json:"stop"`
	// StreamOptions any `json:"stream_options"`
	// Tools any `json:"tools"`
	// ToolChoice string `json:"tool_choice"`
	// Logprobs bool `json:"logprobs"`
	// TopLogprobs any `json:"top_logprobs"`
}
// NewDSCharReq builds a DeepSeek chat request from the current chat body,
// copying the conversation state and filling the sampling knobs with fixed
// defaults (2048 max tokens, temperature 1.0, top_p 1.0, no penalties).
// NOTE(review): the name looks like a typo for NewDSChatReq; kept as-is so
// callers elsewhere keep compiling.
func NewDSCharReq(cb *ChatBody) DSChatReq {
	req := DSChatReq{
		Messages: cb.Messages,
		Model:    cb.Model,
		Stream:   cb.Stream,
	}
	req.MaxTokens = 2048
	req.Temperature = 1.0
	req.TopP = 1.0
	req.PresencePenalty = 0
	req.FrequencyPenalty = 0
	return req
}
// DSCompletionReq is the request body for DeepSeek's (beta) /completions
// endpoint, which takes a flat prompt string instead of a message list.
type DSCompletionReq struct {
	Model            string `json:"model"`
	Prompt           string `json:"prompt"`
	Echo             bool   `json:"echo"`
	FrequencyPenalty int    `json:"frequency_penalty"`
	// Logprobs int `json:"logprobs"`
	MaxTokens       int `json:"max_tokens"`
	PresencePenalty int `json:"presence_penalty"`
	// Stop is typed any so it can carry either a string or a list of stop
	// sequences, mirroring the API's union type.
	Stop          any     `json:"stop"`
	Stream        bool    `json:"stream"`
	StreamOptions any     `json:"stream_options"`
	Suffix        any     `json:"suffix"`
	Temperature   float32 `json:"temperature"`
	TopP          float32 `json:"top_p"`
}
// NewDSCompletionReq builds a streaming DeepSeek /completions request for
// the given prompt and model. Stop sequences are derived from the
// configured role names so the model halts before speaking for the user.
func NewDSCompletionReq(prompt, model string, temp float32, cfg *config.Config) DSCompletionReq {
	stopSeqs := []string{
		cfg.UserRole + ":\n", "<|im_end|>",
		cfg.ToolRole + ":\n",
		cfg.AssistantRole + ":\n",
	}
	return DSCompletionReq{
		Model:            model,
		Prompt:           prompt,
		Temperature:      temp,
		Stream:           true,
		Echo:             false,
		MaxTokens:        2048,
		PresencePenalty:  0,
		FrequencyPenalty: 0,
		TopP:             1.0,
		Stop:             stopSeqs,
	}
}
// DSCompletionResp is one streamed chunk (or the final message) from
// DeepSeek's /completions endpoint. Text carries the token delta; a
// non-empty FinishReason on a choice marks the end of the stream.
type DSCompletionResp struct {
	ID      string `json:"id"`
	Choices []struct {
		FinishReason string `json:"finish_reason"`
		Index        int    `json:"index"`
		Logprobs     struct {
			TextOffset    []int    `json:"text_offset"`
			TokenLogprobs []int    `json:"token_logprobs"`
			Tokens        []string `json:"tokens"`
			TopLogprobs   []struct {
			} `json:"top_logprobs"`
		} `json:"logprobs"`
		Text string `json:"text"`
	} `json:"choices"`
	Created           int    `json:"created"`
	Model             string `json:"model"`
	SystemFingerprint string `json:"system_fingerprint"`
	Object            string `json:"object"`
	// Usage is only populated on the terminal message of a stream.
	// NOTE(review): assumed from OpenAI-compatible streaming convention —
	// confirm against the DeepSeek docs.
	Usage struct {
		CompletionTokens      int `json:"completion_tokens"`
		PromptTokens          int `json:"prompt_tokens"`
		PromptCacheHitTokens  int `json:"prompt_cache_hit_tokens"`
		PromptCacheMissTokens int `json:"prompt_cache_miss_tokens"`
		TotalTokens           int `json:"total_tokens"`
		CompletionTokensDetails struct {
			ReasoningTokens int `json:"reasoning_tokens"`
		} `json:"completion_tokens_details"`
	} `json:"usage"`
}
// DSChatResp is one streamed chunk from DeepSeek's /chat/completions
// endpoint; Delta.Content carries the incremental text and a non-empty
// FinishReason marks the end of the stream.
type DSChatResp struct {
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
			Role    any    `json:"role"`
		} `json:"delta"`
		FinishReason string `json:"finish_reason"`
		Index        int    `json:"index"`
		Logprobs     any    `json:"logprobs"`
	} `json:"choices"`
	Created           int    `json:"created"`
	ID                string `json:"id"`
	Model             string `json:"model"`
	Object            string `json:"object"`
	SystemFingerprint string `json:"system_fingerprint"`
	Usage             struct {
		CompletionTokens int `json:"completion_tokens"`
		PromptTokens     int `json:"prompt_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
}
type EmbeddingResp struct { type EmbeddingResp struct {
Embedding []float32 `json:"embedding"` Embedding []float32 `json:"embedding"`
Index uint32 `json:"index"` Index uint32 `json:"index"`
@@ -190,3 +310,13 @@ type LlamaCPPResp struct {
Content string `json:"content"` Content string `json:"content"`
Stop bool `json:"stop"` Stop bool `json:"stop"`
} }
// DSBalance is the response of DeepSeek's /user/balance endpoint.
// Monetary amounts are returned by the API as strings, not numbers.
type DSBalance struct {
	IsAvailable  bool `json:"is_available"`
	BalanceInfos []struct {
		Currency        string `json:"currency"`
		TotalBalance    string `json:"total_balance"`
		GrantedBalance  string `json:"granted_balance"`
		ToppedUpBalance string `json:"topped_up_balance"`
	} `json:"balance_infos"`
}

10
tui.go
View File

@@ -136,7 +136,7 @@ func colorText() {
} }
func updateStatusLine() { func updateStatusLine() {
position.SetText(fmt.Sprintf(indexLine, botRespMode, cfg.AssistantRole, activeChatName, cfg.RAGEnabled, cfg.ToolUse, currentModel, cfg.CurrentAPI, cfg.ThinkUse, logLevel.Level())) position.SetText(fmt.Sprintf(indexLine, botRespMode, cfg.AssistantRole, activeChatName, cfg.RAGEnabled, cfg.ToolUse, chatBody.Model, cfg.CurrentAPI, cfg.ThinkUse, logLevel.Level()))
} }
func initSysCards() ([]string, error) { func initSysCards() ([]string, error) {
@@ -202,6 +202,12 @@ func makePropsForm(props map[string]float32) *tview.Form {
}).AddDropDown("Set log level (Enter): ", []string{"Debug", "Info", "Warn"}, 1, }).AddDropDown("Set log level (Enter): ", []string{"Debug", "Info", "Warn"}, 1,
func(option string, optionIndex int) { func(option string, optionIndex int) {
setLogLevel(option) setLogLevel(option)
}).AddDropDown("Select an api: ", cfg.ApiLinks, 1,
func(option string, optionIndex int) {
cfg.CurrentAPI = option
}).AddDropDown("Select a model: ", []string{chatBody.Model, "deepseek-chat", "deepseek-reasoner"}, 0,
func(option string, optionIndex int) {
chatBody.Model = option
}). }).
AddButton("Quit", func() { AddButton("Quit", func() {
pages.RemovePage(propsPage) pages.RemovePage(propsPage)
@@ -600,7 +606,7 @@ func init() {
} }
cfg.APIMap[newAPI] = prevAPI cfg.APIMap[newAPI] = prevAPI
cfg.CurrentAPI = newAPI cfg.CurrentAPI = newAPI
initChunkParser() choseChunkParser()
updateStatusLine() updateStatusLine()
return nil return nil
} }