Feat: add character card support
This commit is contained in:
107
pngmeta/metareader.go
Normal file
107
pngmeta/metareader.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package pngmeta
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"elefant/models"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
embType = "tEXt"
|
||||
)
|
||||
|
||||
type PngEmbed struct {
|
||||
Key string
|
||||
Value string
|
||||
}
|
||||
|
||||
func (c PngEmbed) GetDecodedValue() (*models.CharCardSpec, error) {
|
||||
data, err := base64.StdEncoding.DecodeString(c.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
card := &models.CharCardSpec{}
|
||||
if err := json.Unmarshal(data, &card); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return card, nil
|
||||
}
|
||||
|
||||
func extractChar(fname string) (*PngEmbed, error) {
|
||||
data, err := os.ReadFile(fname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reader := bytes.NewReader(data)
|
||||
pr, err := NewPNGStepReader(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for {
|
||||
step, err := pr.Next()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
}
|
||||
if step.Type() != embType {
|
||||
if _, err := io.Copy(io.Discard, step); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
buf, err := io.ReadAll(step)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dataInstep := string(buf)
|
||||
values := strings.Split(dataInstep, "\x00")
|
||||
if len(values) == 2 {
|
||||
return &PngEmbed{Key: values[0], Value: values[1]}, nil
|
||||
}
|
||||
}
|
||||
if err := step.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return nil, errors.New("failed to find embedded char in png: " + fname)
|
||||
}
|
||||
|
||||
func ReadCard(fname, uname string) (*models.CharCard, error) {
|
||||
pe, err := extractChar(fname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
charSpec, err := pe.GetDecodedValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return charSpec.Simplify(uname, fname), nil
|
||||
}
|
||||
|
||||
func ReadDirCards(dirname, uname string) ([]*models.CharCard, error) {
|
||||
files, err := os.ReadDir(dirname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp := []*models.CharCard{}
|
||||
for _, f := range files {
|
||||
if !strings.HasSuffix(f.Name(), ".png") {
|
||||
continue
|
||||
}
|
||||
fpath := path.Join(dirname, f.Name())
|
||||
cc, err := ReadCard(fpath, uname)
|
||||
if err != nil {
|
||||
// log err
|
||||
return nil, err
|
||||
// continue
|
||||
}
|
||||
resp = append(resp, cc)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
33
pngmeta/metareader_test.go
Normal file
33
pngmeta/metareader_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package pngmeta
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestReadMeta(t *testing.T) {
|
||||
cases := []struct {
|
||||
Filename string
|
||||
}{
|
||||
{
|
||||
Filename: "../sysprompts/default_Seraphina.png",
|
||||
},
|
||||
{
|
||||
Filename: "../sysprompts/llama.png",
|
||||
},
|
||||
}
|
||||
for i, tc := range cases {
|
||||
t.Run(fmt.Sprintf("test_%d", i), func(t *testing.T) {
|
||||
// Call the readMeta function
|
||||
pembed, err := extractChar(tc.Filename)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, but got %v", err)
|
||||
}
|
||||
v, err := pembed.GetDecodedValue()
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, but got %v\n", err)
|
||||
}
|
||||
fmt.Printf("%+v\n", v.Simplify("Adam"))
|
||||
})
|
||||
}
|
||||
}
|
||||
77
pngmeta/partsreader.go
Normal file
77
pngmeta/partsreader.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package pngmeta
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"hash"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrCRC32Mismatch = errors.New("crc32 mismatch")
|
||||
ErrNotPNG = errors.New("not png")
|
||||
ErrBadLength = errors.New("bad length")
|
||||
)
|
||||
|
||||
const header = "\x89PNG\r\n\x1a\n"
|
||||
|
||||
type PngChunk struct {
|
||||
typ string
|
||||
length int32
|
||||
r io.Reader
|
||||
realR io.Reader
|
||||
checksummer hash.Hash32
|
||||
}
|
||||
|
||||
func (c *PngChunk) Read(p []byte) (int, error) {
|
||||
return io.TeeReader(c.r, c.checksummer).Read(p)
|
||||
}
|
||||
|
||||
func (c *PngChunk) Close() error {
|
||||
var crc32 uint32
|
||||
if err := binary.Read(c.realR, binary.BigEndian, &crc32); err != nil {
|
||||
return err
|
||||
}
|
||||
if crc32 != c.checksummer.Sum32() {
|
||||
return ErrCRC32Mismatch
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *PngChunk) Type() string {
|
||||
return c.typ
|
||||
}
|
||||
|
||||
type Reader struct {
|
||||
r io.Reader
|
||||
}
|
||||
|
||||
func NewPNGStepReader(r io.Reader) (*Reader, error) {
|
||||
expectedHeader := make([]byte, len(header))
|
||||
if _, err := io.ReadFull(r, expectedHeader); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if string(expectedHeader) != header {
|
||||
return nil, ErrNotPNG
|
||||
}
|
||||
return &Reader{r}, nil
|
||||
}
|
||||
|
||||
func (r *Reader) Next() (*PngChunk, error) {
|
||||
var length int32
|
||||
if err := binary.Read(r.r, binary.BigEndian, &length); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if length < 0 {
|
||||
return nil, ErrBadLength
|
||||
}
|
||||
var rawTyp [4]byte
|
||||
if _, err := io.ReadFull(r.r, rawTyp[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
typ := string(rawTyp[:])
|
||||
checksummer := crc32.NewIEEE()
|
||||
checksummer.Write([]byte(typ))
|
||||
return &PngChunk{typ, length, io.LimitReader(r.r, int64(length)), r.r, checksummer}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user