Feat: add deepseek integration [WIP] (only completion works)

This commit is contained in:
Grail Finder
2025-02-28 16:16:11 +03:00
parent 49409f5d94
commit fd1ac24d75
5 changed files with 329 additions and 23 deletions

93
bot.go
View File

@@ -2,6 +2,8 @@ package main
import (
"bufio"
"bytes"
"context"
"elefant/config"
"elefant/models"
"elefant/rag"
@@ -10,6 +12,7 @@ import (
"fmt"
"io"
"log/slog"
"net"
"net/http"
"os"
"path"
@@ -20,7 +23,30 @@ import (
"github.com/rivo/tview"
)
var httpClient = http.Client{}
var httpClient = &http.Client{}
func createClient(connectTimeout time.Duration) *http.Client {
// Custom transport with connection timeout
transport := &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
// Create a dialer with connection timeout
dialer := &net.Dialer{
Timeout: connectTimeout,
KeepAlive: 30 * time.Second, // Optional
}
return dialer.DialContext(ctx, network, addr)
},
// Other transport settings (optional)
TLSHandshakeTimeout: connectTimeout,
ResponseHeaderTimeout: connectTimeout,
}
// Client with no overall timeout (or set to streaming-safe duration)
return &http.Client{
Transport: transport,
Timeout: 0, // No overall timeout (for streaming)
}
}
var (
cfg *config.Config
@@ -36,7 +62,6 @@ var (
defaultStarterBytes = []byte{}
interruptResp = false
ragger *rag.RAG
currentModel = "none"
chunkParser ChunkParser
defaultLCPProps = map[string]float32{
"temperature": 0.8,
@@ -47,6 +72,7 @@ var (
)
func fetchModelName() *models.LLMModels {
// TODO: to config
api := "http://localhost:8080/v1/models"
//nolint
resp, err := httpClient.Get(api)
@@ -61,16 +87,44 @@ func fetchModelName() *models.LLMModels {
return nil
}
if resp.StatusCode != 200 {
currentModel = "disconnected"
chatBody.Model = "disconnected"
return nil
}
currentModel = path.Base(llmModel.Data[0].ID)
chatBody.Model = path.Base(llmModel.Data[0].ID)
return &llmModel
}
// fetchDSBalance queries the DeepSeek balance endpoint using the configured
// token and returns the decoded balance info. All failures return nil after
// logging a warning — the balance is informational only, so errors are not
// propagated to the caller.
func fetchDSBalance() *models.DSBalance {
	url := "https://api.deepseek.com/user/balance"
	req, err := http.NewRequest("GET", url, nil)
	if err != nil {
		logger.Warn("failed to create request", "error", err)
		return nil
	}
	req.Header.Add("Accept", "application/json")
	req.Header.Add("Authorization", "Bearer "+cfg.DeepSeekToken)
	res, err := httpClient.Do(req)
	if err != nil {
		logger.Warn("failed to make request", "error", err)
		return nil
	}
	defer res.Body.Close()
	resp := models.DSBalance{}
	if err := json.NewDecoder(res.Body).Decode(&resp); err != nil {
		// Was previously a silent nil return; log so a bad token or an
		// API schema change is visible in the logs.
		logger.Warn("failed to decode balance response", "error", err)
		return nil
	}
	return &resp
}
func sendMsgToLLM(body io.Reader) {
choseChunkParser()
req, err := http.NewRequest("POST", cfg.CurrentAPI, body)
req.Header.Add("Accept", "application/json")
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", "Bearer "+cfg.DeepSeekToken)
// nolint
resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", body)
// resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", body)
if err != nil {
logger.Error("llamacpp api", "error", err)
if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil {
@@ -79,6 +133,16 @@ func sendMsgToLLM(body io.Reader) {
streamDone <- true
return
}
resp, err := httpClient.Do(req)
if err != nil {
bodyBytes, _ := io.ReadAll(body)
logger.Error("llamacpp api", "error", err, "body", string(bodyBytes))
if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil {
logger.Error("failed to notify", "error", err)
}
streamDone <- true
return
}
defer resp.Body.Close()
reader := bufio.NewReader(resp.Body)
counter := uint32(0)
@@ -113,6 +177,10 @@ func sendMsgToLLM(body io.Reader) {
// starts with -> data:
line = line[6:]
logger.Debug("debugging resp", "line", string(line))
if bytes.Equal(line, []byte("[DONE]\n")) {
streamDone <- true
break
}
content, stop, err = chunkParser.ParseChunk(line)
if err != nil {
logger.Error("error parsing response body", "error", err, "line", string(line), "url", cfg.CurrentAPI)
@@ -185,7 +253,17 @@ func roleToIcon(role string) string {
func chatRound(userMsg, role string, tv *tview.TextView, regen, resume bool) {
botRespMode = true
// reader := formMsg(chatBody, userMsg, role)
defer func() { botRespMode = false }()
// check that there is a model set to use if is not local
if cfg.CurrentAPI == cfg.DeepSeekChatAPI || cfg.CurrentAPI == cfg.DeepSeekCompletionAPI {
if chatBody.Model != "deepseek-chat" && chatBody.Model != "deepseek-reasoner" {
if err := notifyUser("bad request", "wrong deepseek model name"); err != nil {
logger.Warn("failed ot notify user", "error", err)
return
}
return
}
}
reader, err := chunkParser.FormMsg(userMsg, role, resume)
if reader == nil || err != nil {
logger.Error("empty reader from msgs", "role", role, "error", err)
@@ -369,7 +447,8 @@ func init() {
Stream: true,
Messages: lastChat,
}
initChunkParser()
choseChunkParser()
httpClient = createClient(time.Second * 15)
// go runModelNameTicker(time.Second * 120)
// tempLoad()
}

View File

@@ -10,6 +10,7 @@ type Config struct {
ChatAPI string `toml:"ChatAPI"`
CompletionAPI string `toml:"CompletionAPI"`
CurrentAPI string
CurrentProvider string
APIMap map[string]string
//
ShowSys bool `toml:"ShowSys"`
@@ -30,6 +31,12 @@ type Config struct {
RAGWorkers uint32 `toml:"RAGWorkers"`
RAGBatchSize int `toml:"RAGBatchSize"`
RAGWordLimit uint32 `toml:"RAGWordLimit"`
// deepseek
DeepSeekChatAPI string `toml:"DeepSeekChatAPI"`
DeepSeekCompletionAPI string `toml:"DeepSeekCompletionAPI"`
DeepSeekToken string `toml:"DeepSeekToken"`
DeepSeekModel string `toml:"DeepSeekModel"`
ApiLinks []string
}
func LoadConfigOrDefault(fn string) *Config {
@@ -39,9 +46,11 @@ func LoadConfigOrDefault(fn string) *Config {
config := &Config{}
_, err := toml.DecodeFile(fn, &config)
if err != nil {
fmt.Println("failed to read config from file, loading default")
fmt.Println("failed to read config from file, loading default", "error", err)
config.ChatAPI = "http://localhost:8080/v1/chat/completions"
config.CompletionAPI = "http://localhost:8080/completion"
config.DeepSeekCompletionAPI = "https://api.deepseek.com/beta/completions"
config.DeepSeekChatAPI = "https://api.deepseek.com/chat/completions"
config.RAGEnabled = false
config.EmbedURL = "http://localhost:8080/v1/embiddings"
config.ShowSys = true
@@ -59,11 +68,15 @@ func LoadConfigOrDefault(fn string) *Config {
config.CurrentAPI = config.ChatAPI
config.APIMap = map[string]string{
config.ChatAPI: config.CompletionAPI,
config.DeepSeekChatAPI: config.DeepSeekCompletionAPI,
}
if config.CompletionAPI != "" {
config.CurrentAPI = config.CompletionAPI
config.APIMap = map[string]string{
config.CompletionAPI: config.ChatAPI,
config.APIMap[config.CompletionAPI] = config.ChatAPI
}
for _, el := range []string{config.ChatAPI, config.CompletionAPI, config.DeepSeekChatAPI, config.DeepSeekCompletionAPI} {
if el != "" {
config.ApiLinks = append(config.ApiLinks, el)
}
}
// if any value is empty fill with default

90
llm.go
View File

@@ -13,20 +13,32 @@ type ChunkParser interface {
FormMsg(msg, role string, cont bool) (io.Reader, error)
}
func initChunkParser() {
func choseChunkParser() {
chunkParser = LlamaCPPeer{}
if strings.Contains(cfg.CurrentAPI, "v1") {
logger.Debug("chosen /v1/chat parser")
switch cfg.CurrentAPI {
case "http://localhost:8080/completion":
chunkParser = LlamaCPPeer{}
case "http://localhost:8080/v1/chat/completions":
chunkParser = OpenAIer{}
return
case "https://api.deepseek.com/beta/completions":
chunkParser = DeepSeeker{}
default:
chunkParser = LlamaCPPeer{}
}
logger.Debug("chosen llamacpp /completion parser")
// if strings.Contains(cfg.CurrentAPI, "chat") {
// logger.Debug("chosen chat parser")
// chunkParser = OpenAIer{}
// return
// }
// logger.Debug("chosen llamacpp /completion parser")
}
// LlamaCPPeer handles the llama.cpp /completion endpoint
// (see choseChunkParser for the endpoint-to-parser mapping).
type LlamaCPPeer struct {
}

// OpenAIer handles the OpenAI-compatible /v1/chat/completions endpoint.
type OpenAIer struct {
}

// DeepSeeker handles the DeepSeek beta /completions endpoint.
type DeepSeeker struct {
}
func (lcp LlamaCPPeer) FormMsg(msg, role string, resume bool) (io.Reader, error) {
if msg != "" { // otherwise let the bot to continue
@@ -62,7 +74,12 @@ func (lcp LlamaCPPeer) FormMsg(msg, role string, resume bool) (io.Reader, error)
}
logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse,
"msg", msg, "resume", resume, "prompt", prompt)
payload := models.NewLCPReq(prompt, cfg, defaultLCPProps)
var payload any
payload = models.NewLCPReq(prompt, cfg, defaultLCPProps)
if strings.Contains(chatBody.Model, "deepseek") {
payload = models.NewDSCompletionReq(prompt, chatBody.Model,
defaultLCPProps["temp"], cfg)
}
data, err := json.Marshal(payload)
if err != nil {
logger.Error("failed to form a msg", "error", err)
@@ -129,3 +146,64 @@ func (op OpenAIer) FormMsg(msg, role string, resume bool) (io.Reader, error) {
}
return bytes.NewReader(data), nil
}
// deepseek
// ParseChunk decodes one streaming chunk of a DeepSeek /completions response.
// It returns the extracted text, whether the stream is finished (non-empty
// finish_reason), and any decode error.
func (ds DeepSeeker) ParseChunk(data []byte) (string, bool, error) {
	llmchunk := models.DSCompletionResp{}
	if err := json.Unmarshal(data, &llmchunk); err != nil {
		logger.Error("failed to decode", "error", err, "line", string(data))
		return "", false, err
	}
	// Guard against an empty choices array; indexing [0] unconditionally
	// would panic on a malformed or keep-alive chunk.
	if len(llmchunk.Choices) == 0 {
		logger.Error("deepseek chunk has no choices", "line", string(data))
		return "", false, io.ErrUnexpectedEOF
	}
	if llmchunk.Choices[0].FinishReason != "" {
		if llmchunk.Choices[0].Text != "" {
			logger.Error("text inside of finish llmchunk", "chunk", llmchunk)
		}
		return llmchunk.Choices[0].Text, true, nil
	}
	return llmchunk.Choices[0].Text, false, nil
}
// FormMsg serializes the chat history (plus the new user message, optional
// RAG context, and tool/think scaffolding) into a DeepSeek /completions
// request body.
//
// When resume is true no fresh assistant turn marker is appended, so the
// model continues the last message instead of starting a new one.
func (ds DeepSeeker) FormMsg(msg, role string, resume bool) (io.Reader, error) {
	if msg != "" { // otherwise let the bot continue
		newMsg := models.RoleMsg{Role: role, Content: msg}
		chatBody.Messages = append(chatBody.Messages, newMsg)
		// if rag is enabled, fetch related context and inject it as a tool msg
		if cfg.RAGEnabled {
			ragResp, err := chatRagUse(newMsg.Content)
			if err != nil {
				logger.Error("failed to form a rag msg", "error", err)
				return nil, err
			}
			ragMsg := models.RoleMsg{Role: cfg.ToolRole, Content: ragResp}
			chatBody.Messages = append(chatBody.Messages, ragMsg)
		}
	}
	if cfg.ToolUse && !resume {
		// add tool instructions to the chat body
		chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg})
	}
	// Flatten the structured history into a single completion prompt.
	messages := make([]string, len(chatBody.Messages))
	for i, m := range chatBody.Messages {
		messages[i] = m.ToPrompt()
	}
	prompt := strings.Join(messages, "\n")
	if !resume {
		// open a fresh assistant turn for the model to fill in
		botMsgStart := "\n" + cfg.AssistantRole + ":\n"
		prompt += botMsgStart
	}
	if cfg.ThinkUse && !cfg.ToolUse {
		prompt += "<think>"
	}
	logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse,
		"msg", msg, "resume", resume, "prompt", prompt)
	// NOTE(review): defaultLCPProps visibly declares a "temperature" key;
	// if "temp" is not also declared this lookup yields 0.0 — confirm the key.
	payload := models.NewDSCompletionReq(prompt, chatBody.Model,
		defaultLCPProps["temp"], cfg)
	data, err := json.Marshal(payload)
	if err != nil {
		logger.Error("failed to form a msg", "error", err)
		return nil, err
	}
	return bytes.NewReader(data), nil
}

View File

@@ -103,6 +103,126 @@ type ChatToolsBody struct {
ToolChoice string `json:"tool_choice"`
}
// DSChatReq is the request body for the DeepSeek /chat/completions endpoint.
type DSChatReq struct {
	Messages         []RoleMsg `json:"messages"`
	Model            string    `json:"model"`
	Stream           bool      `json:"stream"`
	FrequencyPenalty int       `json:"frequency_penalty"`
	MaxTokens        int       `json:"max_tokens"`
	PresencePenalty  int       `json:"presence_penalty"`
	Temperature      float32   `json:"temperature"`
	TopP             float32   `json:"top_p"`
	// Optional request fields kept commented out until needed.
	// ResponseFormat struct {
	// 	Type string `json:"type"`
	// } `json:"response_format"`
	// Stop any `json:"stop"`
	// StreamOptions any `json:"stream_options"`
	// Tools any `json:"tools"`
	// ToolChoice string `json:"tool_choice"`
	// Logprobs bool `json:"logprobs"`
	// TopLogprobs any `json:"top_logprobs"`
}
// NewDSCharReq builds a DSChatReq from the current chat body, copying the
// messages, model, and stream flag and filling in fixed default sampling
// parameters (2048 max tokens, temperature 1.0, no penalties).
// NOTE(review): the name looks like a typo for NewDSChatReq; renaming would
// break callers, so it is only flagged here.
func NewDSCharReq(cb *ChatBody) DSChatReq {
	return DSChatReq{
		Messages:         cb.Messages,
		Model:            cb.Model,
		Stream:           cb.Stream,
		MaxTokens:        2048,
		PresencePenalty:  0,
		FrequencyPenalty: 0,
		Temperature:      1.0,
		TopP:             1.0,
	}
}
// DSCompletionReq is the request body for the DeepSeek (beta) /completions
// endpoint, which takes a flat text prompt rather than structured messages.
type DSCompletionReq struct {
	Model            string `json:"model"`
	Prompt           string `json:"prompt"`
	Echo             bool   `json:"echo"`
	FrequencyPenalty int    `json:"frequency_penalty"`
	// Logprobs int `json:"logprobs"`
	MaxTokens       int     `json:"max_tokens"`
	PresencePenalty int     `json:"presence_penalty"`
	Stop            any     `json:"stop"`
	Stream          bool    `json:"stream"`
	StreamOptions   any     `json:"stream_options"`
	Suffix          any     `json:"suffix"`
	Temperature     float32 `json:"temperature"`
	TopP            float32 `json:"top_p"`
}
// NewDSCompletionReq builds a streaming DeepSeek /completions request for
// the given prompt, model, and temperature, with fixed default sampling
// parameters and role-based stop sequences taken from the config.
func NewDSCompletionReq(prompt, model string, temp float32, cfg *config.Config) DSCompletionReq {
	// Stop generation when the model starts speaking as another role.
	stopSeqs := []string{
		cfg.UserRole + ":\n",
		"<|im_end|>",
		cfg.ToolRole + ":\n",
		cfg.AssistantRole + ":\n",
	}
	return DSCompletionReq{
		Model:            model,
		Prompt:           prompt,
		Temperature:      temp,
		Stream:           true,
		Echo:             false,
		MaxTokens:        2048,
		PresencePenalty:  0,
		FrequencyPenalty: 0,
		TopP:             1.0,
		Stop:             stopSeqs,
	}
}
// DSCompletionResp is one response payload (a streamed chunk or a full
// response) from the DeepSeek /completions endpoint.
type DSCompletionResp struct {
	ID      string `json:"id"`
	Choices []struct {
		// FinishReason is non-empty on the final chunk of a stream.
		FinishReason string `json:"finish_reason"`
		Index        int    `json:"index"`
		Logprobs     struct {
			TextOffset    []int    `json:"text_offset"`
			TokenLogprobs []int    `json:"token_logprobs"`
			Tokens        []string `json:"tokens"`
			TopLogprobs   []struct {
			} `json:"top_logprobs"`
		} `json:"logprobs"`
		// Text holds the generated completion text for this choice.
		Text string `json:"text"`
	} `json:"choices"`
	Created           int    `json:"created"`
	Model             string `json:"model"`
	SystemFingerprint string `json:"system_fingerprint"`
	Object            string `json:"object"`
	Usage             struct {
		CompletionTokens      int `json:"completion_tokens"`
		PromptTokens          int `json:"prompt_tokens"`
		PromptCacheHitTokens  int `json:"prompt_cache_hit_tokens"`
		PromptCacheMissTokens int `json:"prompt_cache_miss_tokens"`
		TotalTokens           int `json:"total_tokens"`
		CompletionTokensDetails struct {
			ReasoningTokens int `json:"reasoning_tokens"`
		} `json:"completion_tokens_details"`
	} `json:"usage"`
}
// DSChatResp is one streamed chunk from the DeepSeek /chat/completions
// endpoint; generated text arrives incrementally in Choices[i].Delta.Content.
type DSChatResp struct {
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
			Role    any    `json:"role"`
		} `json:"delta"`
		// FinishReason is non-empty on the final chunk of a stream.
		FinishReason string `json:"finish_reason"`
		Index        int    `json:"index"`
		Logprobs     any    `json:"logprobs"`
	} `json:"choices"`
	Created           int    `json:"created"`
	ID                string `json:"id"`
	Model             string `json:"model"`
	Object            string `json:"object"`
	SystemFingerprint string `json:"system_fingerprint"`
	Usage             struct {
		CompletionTokens int `json:"completion_tokens"`
		PromptTokens     int `json:"prompt_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
}
type EmbeddingResp struct {
Embedding []float32 `json:"embedding"`
Index uint32 `json:"index"`
@@ -190,3 +310,13 @@ type LlamaCPPResp struct {
Content string `json:"content"`
Stop bool `json:"stop"`
}
// DSBalance is the response of the DeepSeek /user/balance endpoint
// (see fetchDSBalance). Monetary amounts arrive as strings.
type DSBalance struct {
	IsAvailable  bool `json:"is_available"`
	BalanceInfos []struct {
		Currency        string `json:"currency"`
		TotalBalance    string `json:"total_balance"`
		GrantedBalance  string `json:"granted_balance"`
		ToppedUpBalance string `json:"topped_up_balance"`
	} `json:"balance_infos"`
}

10
tui.go
View File

@@ -136,7 +136,7 @@ func colorText() {
}
// updateStatusLine refreshes the status bar widget with the current session
// state (bot activity, roles, active chat, RAG/tool/think flags, model,
// API endpoint, and log level). Only the current chatBody.Model variant is
// kept; the span previously also carried the superseded currentModel call.
func updateStatusLine() {
	position.SetText(fmt.Sprintf(indexLine, botRespMode, cfg.AssistantRole, activeChatName, cfg.RAGEnabled, cfg.ToolUse, chatBody.Model, cfg.CurrentAPI, cfg.ThinkUse, logLevel.Level()))
}
func initSysCards() ([]string, error) {
@@ -202,6 +202,12 @@ func makePropsForm(props map[string]float32) *tview.Form {
}).AddDropDown("Set log level (Enter): ", []string{"Debug", "Info", "Warn"}, 1,
func(option string, optionIndex int) {
setLogLevel(option)
}).AddDropDown("Select an api: ", cfg.ApiLinks, 1,
func(option string, optionIndex int) {
cfg.CurrentAPI = option
}).AddDropDown("Select a model: ", []string{chatBody.Model, "deepseek-chat", "deepseek-reasoner"}, 0,
func(option string, optionIndex int) {
chatBody.Model = option
}).
AddButton("Quit", func() {
pages.RemovePage(propsPage)
@@ -600,7 +606,7 @@ func init() {
}
cfg.APIMap[newAPI] = prevAPI
cfg.CurrentAPI = newAPI
initChunkParser()
choseChunkParser()
updateStatusLine()
return nil
}