diff --git a/storage/storage.go b/storage/storage.go index 57631da..980501f 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -1,6 +1,7 @@ package storage import ( + "database/sql" "gf-lt/models" "log/slog" @@ -12,6 +13,12 @@ type FullRepo interface { ChatHistory Memories VectorRepo + TableLister +} + +type TableLister interface { + ListTables() ([]string, error) + GetTableColumns(table string) ([]TableColumn, error) } type ChatHistory interface { @@ -130,3 +137,24 @@ func NewProviderSQL(dbPath string, logger *slog.Logger) FullRepo { func (p ProviderSQL) DB() *sqlx.DB { return p.db } + +type TableColumn struct { + CID int `db:"cid"` + Name string `db:"name"` + Type string `db:"type"` + NotNull bool `db:"notnull"` + DFltVal sql.NullString `db:"dflt_value"` + PK int `db:"pk"` +} + +func (p ProviderSQL) ListTables() ([]string, error) { + resp := []string{} + err := p.db.Select(&resp, "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name;") + return resp, err +} + +func (p ProviderSQL) GetTableColumns(table string) ([]TableColumn, error) { + resp := []TableColumn{} + err := p.db.Select(&resp, "PRAGMA table_info("+table+");") + return resp, err +} diff --git a/tables.go b/tables.go index d175ae6..8494b2a 100644 --- a/tables.go +++ b/tables.go @@ -273,7 +273,7 @@ func makeRAGTable(fileList []string, loadedFiles []string) *tview.Flex { fileTable := tview.NewTable(). SetBorders(true) longStatusView := tview.NewTextView() - longStatusView.SetText("press x to exit") + longStatusView.SetText("press x to exit | press d to view DB") longStatusView.SetBorder(true).SetTitle("status") longStatusView.SetChangedFunc(func() { app.Draw() @@ -498,6 +498,14 @@ func makeRAGTable(fileList []string, loadedFiles []string) *tview.Flex { pages.RemovePage(RAGPage) return nil } + if event.Key() == tcell.KeyRune && event.Rune() == 'd' { + pages.RemovePage(RAGPage) + dbTable := makeDbTable() + if dbTable != nil { + pages.AddPage(dbTablesPage, dbTable, true, true) + } + return nil + } return event }) return ragflex @@ -1189,3 +1197,376 @@ func makeFilePicker() *tview.Flex { }) return flex } + +func makeDbTable() *tview.Flex { + tables, err := store.ListTables() + if err != nil { + logger.Error("failed to list tables", "error", err) + showToast("error", "failed to list tables: "+err.Error()) + return nil + } + if len(tables) == 0 { + showToast("info", "no tables found in database") + return nil + } + tblList := tview.NewList().ShowSecondaryText(false) + rowCounts := make(map[string]int) + for _, t := range tables { + var count int + _ = store.DB().Get(&count, "SELECT COUNT(*) FROM "+t) + rowCounts[t] = count + tblList.AddItem(t, fmt.Sprintf("%d rows", count), 0, nil) + } + tblList.SetBorder(true).SetTitle("Tables") + dataTable := tview.NewTable().SetBorders(true) + dataTable.SetBorder(true).SetTitle("Data") + flex := tview.NewFlex(). + AddItem(tblList, 0, 1, true). + AddItem(dataTable, 0, 2, false) + loadTableData := func(tableName string, tbl *tview.Table) { + rows, err := store.DB().Queryx("SELECT * FROM " + tableName + " LIMIT 80") + if err != nil { + logger.Error("failed to query table", "table", tableName, "error", err) + return + } + columnNames, _ := rows.Columns() + tbl.Clear() + for c, name := range columnNames { + tbl.SetCell(0, c, + tview.NewTableCell(name). + SetTextColor(tcell.ColorYellow). + SetAlign(tview.AlignCenter)) + } + r := 1 + for rows.Next() { + row := make(map[string]interface{}) + if err := rows.MapScan(row); err != nil { + continue + } + for c, name := range columnNames { + val, ok := row[name] + var cellText string + var color tcell.Color + if !ok || val == nil { + cellText = "NULL" + color = tcell.ColorDarkGray + } else { + cellText = fmt.Sprintf("%v", val) + if len(cellText) > 30 { + cellText = cellText[:30] + "..." + } + color = tcell.ColorWhite + } + tbl.SetCell(r, c, + tview.NewTableCell(cellText). + SetTextColor(color). + SetAlign(tview.AlignCenter)) + } + r++ + } + rows.Close() + tbl.Select(0, 0) + } + tblList.SetSelectedFunc(func(idx int, mainText, secondaryText string, rune rune) { + if idx >= 0 && idx < len(tables) { + loadTableData(tables[idx], dataTable) + dataTable.SetBorder(true).SetTitle("Data: " + tables[idx]) + } + }) + tblList.SetChangedFunc(func(idx int, mainText, secondaryText string, rune rune) { + if idx >= 0 && idx < len(tables) { + loadTableData(tables[idx], dataTable) + dataTable.SetBorder(true).SetTitle("Data: " + tables[idx]) + } + }) + tblList.SetDoneFunc(func() { + pages.RemovePage(dbTablesPage) + }) + tblList.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { + if event.Key() == tcell.KeyRune && event.Rune() == 'x' { + pages.RemovePage(dbTablesPage) + app.SetFocus(textArea) + return nil + } + if event.Key() == tcell.KeyEnter { + idx := tblList.GetCurrentItem() + if idx >= 0 && idx < len(tables) { + showDbContentView(tables[idx]) + } + return nil + } + return event + }) + if len(tables) > 0 { + tblList.SetCurrentItem(0) + } + return flex +} + +func updateColumnsView(tableName string, tbl *tview.Table) { + columns, err := store.GetTableColumns(tableName) + if err != nil { + logger.Error("failed to get table columns", "table", tableName, "error", err) + return + } + tbl.Clear() + cols := 5 + tbl.SetFixed(1, 0) + for c := 0; c < cols; c++ { + color := tcell.ColorYellow + var headerText string + switch c { + case 0: + headerText = "CID" + case 1: + headerText = "Name" + case 2: + headerText = "Type" + case 3: + headerText = "NotNull" + case 4: + headerText = "PK" + } + tbl.SetCell(0, c, + tview.NewTableCell(headerText). + SetTextColor(color). + SetAlign(tview.AlignCenter). + SetSelectable(false)) + } + for r, col := range columns { + for c := 0; c < cols; c++ { + color := tcell.ColorWhite + if col.PK > 0 { + color = tcell.ColorRed + } + switch c { + case 0: + tbl.SetCell(r+1, c, + tview.NewTableCell(fmt.Sprintf("%d", col.CID)). + SetTextColor(color). + SetAlign(tview.AlignCenter). + SetSelectable(false)) + case 1: + tbl.SetCell(r+1, c, + tview.NewTableCell(col.Name). + SetTextColor(color). + SetAlign(tview.AlignCenter). + SetSelectable(false)) + case 2: + tbl.SetCell(r+1, c, + tview.NewTableCell(col.Type). + SetTextColor(color). + SetAlign(tview.AlignCenter). + SetSelectable(false)) + case 3: + notNull := "N" + if col.NotNull { + notNull = "Y" + } + tbl.SetCell(r+1, c, + tview.NewTableCell(notNull). + SetTextColor(color). + SetAlign(tview.AlignCenter). + SetSelectable(false)) + case 4: + pk := "" + if col.PK > 0 { + pk = fmt.Sprintf("%d", col.PK) + } + tbl.SetCell(r+1, c, + tview.NewTableCell(pk). + SetTextColor(color). + SetAlign(tview.AlignCenter). + SetSelectable(false)) + } + } + } + tbl.Select(0, 0) +} + +func showDbColumnsView(tableName, parentPage string) { + longStatusView := tview.NewTextView() + longStatusView.SetText("table: " + tableName + " | press x to exit | press Enter to view content").SetBorder(true).SetTitle("status") + longStatusView.SetChangedFunc(func() { + app.Draw() + }) + flex := tview.NewFlex().SetDirection(tview.FlexRow). + AddItem(longStatusView, 0, 10, false). + AddItem(tview.NewTable().SetBorders(true), 0, 60, true) + columns, err := store.GetTableColumns(tableName) + if err != nil { + logger.Error("failed to get table columns", "table", tableName, "error", err) + showToast("error", "failed to get columns: "+err.Error()) + return + } + tbl := flex.GetItem(1).(*tview.Table) + cols := 5 // CID | Name | Type | NotNull | PK + tbl.SetFixed(1, 0) + for c := 0; c < cols; c++ { + color := tcell.ColorYellow + var headerText string + switch c { + case 0: + headerText = "CID" + case 1: + headerText = "Name" + case 2: + headerText = "Type" + case 3: + headerText = "NotNull" + case 4: + headerText = "PK" + } + tbl.SetCell(0, c, + tview.NewTableCell(headerText). + SetTextColor(color). + SetAlign(tview.AlignCenter). + SetSelectable(false)) + } + for r, col := range columns { + for c := 0; c < cols; c++ { + color := tcell.ColorWhite + if col.PK > 0 { + color = tcell.ColorRed + } + switch c { + case 0: + tbl.SetCell(r+1, c, + tview.NewTableCell(fmt.Sprintf("%d", col.CID)). + SetTextColor(color). + SetAlign(tview.AlignCenter). + SetSelectable(false)) + case 1: + tbl.SetCell(r+1, c, + tview.NewTableCell(col.Name). + SetTextColor(color). + SetAlign(tview.AlignCenter). + SetSelectable(false)) + case 2: + tbl.SetCell(r+1, c, + tview.NewTableCell(col.Type). + SetTextColor(color). + SetAlign(tview.AlignCenter). + SetSelectable(false)) + case 3: + notNull := "N" + if col.NotNull { + notNull = "Y" + } + tbl.SetCell(r+1, c, + tview.NewTableCell(notNull). + SetTextColor(color). + SetAlign(tview.AlignCenter). + SetSelectable(false)) + case 4: + pk := "" + if col.PK > 0 { + pk = fmt.Sprintf("%d", col.PK) + } + tbl.SetCell(r+1, c, + tview.NewTableCell(pk). + SetTextColor(color). + SetAlign(tview.AlignCenter). + SetSelectable(false)) + } + } + } + columnsPageName := "dbColumns" + pages.AddPage(columnsPageName, flex, true, true) + flex.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { + if event.Key() == tcell.KeyRune && event.Rune() == 'x' { + pages.RemovePage(columnsPageName) + return nil + } + if event.Key() == tcell.KeyEnter { + pages.RemovePage(columnsPageName) + showDbContentView(tableName) + } + return event + }) +} + +func showDbContentView(tableName string) { + batchSize := 80 + longStatusView := tview.NewTextView() + longStatusView.SetText("table: " + tableName + " | press Enter to load more").SetBorder(true).SetTitle("status") + longStatusView.SetChangedFunc(func() { + app.Draw() + }) + tbl := tview.NewTable().SetBorders(true).SetFixed(1, 0) + flex := tview.NewFlex().SetDirection(tview.FlexRow). + AddItem(longStatusView, 0, 10, false). + AddItem(tbl, 0, 60, true) + contentPageName := "db_content_" + tableName + offset := 0 + var rowCount int + _ = store.DB().Get(&rowCount, "SELECT COUNT(*) FROM "+tableName) + var columnNames []string + loadRows := func(off int) { + rows, err := store.DB().Queryx("SELECT * FROM " + tableName + " LIMIT " + fmt.Sprintf("%d", batchSize) + " OFFSET " + fmt.Sprintf("%d", off)) + if err != nil { + logger.Error("failed to query table", "table", tableName, "error", err) + return + } + if off == 0 { + columnNames, _ = rows.Columns() + for c, name := range columnNames { + tbl.SetCell(0, c, + tview.NewTableCell(name). + SetTextColor(tcell.ColorYellow). + SetAlign(tview.AlignCenter). + SetSelectable(false)) + } + } + r := off + for rows.Next() { + row := make(map[string]interface{}) + if err := rows.MapScan(row); err != nil { + logger.Error("failed to scan row", "error", err) + continue + } + for c, name := range columnNames { + val, ok := row[name] + if !ok { + tbl.SetCell(r+1, c, + tview.NewTableCell("NULL"). + SetTextColor(tcell.ColorDarkGray). + SetAlign(tview.AlignCenter). + SetSelectable(false)) + } else { + str := fmt.Sprintf("%v", val) + if len(str) > 50 { + str = str[:50] + "..." + } + tbl.SetCell(r+1, c, + tview.NewTableCell(str). + SetTextColor(tcell.ColorWhite). + SetAlign(tview.AlignCenter). + SetSelectable(false)) + } + } + r++ + } + rows.Close() + loaded := tbl.GetRowCount() - 1 + if loaded < rowCount { + longStatusView.SetText(fmt.Sprintf("table: %s | loaded %d of %d rows | press Enter for more", tableName, loaded, rowCount)) + } else { + longStatusView.SetText(fmt.Sprintf("table: %s | loaded %d rows (all)", tableName, loaded)) + } + } + loadRows(0) + pages.AddPage(contentPageName, flex, true, true) + flex.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { + if event.Key() == tcell.KeyRune && event.Rune() == 'x' { + pages.RemovePage(contentPageName) + return nil + } + if event.Key() == tcell.KeyEnter { + offset += batchSize + loadRows(offset) + tbl.ScrollToEnd() + } + return event + }) +} diff --git a/tui.go b/tui.go index 7f897e4..9275867 100644 --- a/tui.go +++ b/tui.go @@ -51,6 +51,7 @@ var ( helpPage = "helpPage" renamePage = "renamePage" RAGPage = "RAGPage" + dbTablesPage = "dbTables" propsPage = "propsPage" codeBlockPage = "codeBlockPage" imgPage = "imgPage"