fix: support chutes embedding response

This commit is contained in:
shenlan 2025-08-14 20:50:27 +08:00
parent 5509234eff
commit 8596cb4b39
4 changed files with 39 additions and 14 deletions

View File

@ -90,6 +90,7 @@ func (s *StringSlice) UnmarshalYAML(value *yaml.Node) error {
type ModelCfg struct {
Provider string `yaml:"provider"`
Models StringSlice `yaml:"models"`
BaseURL string `yaml:"baseurl"`
Endpoint string `yaml:"endpoint"`
Token string `yaml:"token"`
}

View File

@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
@ -45,12 +46,22 @@ func (c *Chutes) Embed(ctx context.Context, inputs []string) ([][]float32, int,
if resp.StatusCode >= 300 {
return nil, 0, &HTTPError{Code: resp.StatusCode, Status: fmt.Sprintf("embed failed: %s", resp.Status)}
}
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, 0, err
}
var out struct {
Data [][]float32 `json:"data"`
}
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
return nil, 0, err
if err := json.Unmarshal(b, &out); err != nil || len(out.Data) == 0 {
if err := json.Unmarshal(b, &out.Data); err != nil {
return nil, 0, err
}
}
if len(out.Data) != len(inputs) {
return nil, 0, fmt.Errorf("embedding count mismatch")
}

View File

@ -8,18 +8,30 @@ import (
)
func TestChutesEmbed(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"data":[[0.1,0.2],[0.3,0.4]]}`))
}))
defer srv.Close()
emb := NewChutes(srv.URL, "", 0)
vecs, _, err := emb.Embed(context.Background(), []string{"a", "b"})
if err != nil {
t.Fatalf("Embed returned error: %v", err)
cases := []struct {
name string
response string
}{
{"object", `{"data":[[0.1,0.2],[0.3,0.4]]}`},
{"array", `[[0.1,0.2],[0.3,0.4]]`},
}
if len(vecs) != 2 || len(vecs[0]) != 2 {
t.Fatalf("unexpected embedding: %#v", vecs)
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(tc.response))
}))
defer srv.Close()
emb := NewChutes(srv.URL, "", 0)
vecs, _, err := emb.Embed(context.Background(), []string{"a", "b"})
if err != nil {
t.Fatalf("Embed returned error: %v", err)
}
if len(vecs) != 2 || len(vecs[0]) != 2 {
t.Fatalf("unexpected embedding: %#v", vecs)
}
})
}
}

View File

@ -94,6 +94,7 @@ func (s *StringSlice) UnmarshalYAML(value *yaml.Node) error {
type ModelCfg struct {
Provider string `yaml:"provider"`
Models StringSlice `yaml:"models"`
BaseURL string `yaml:"baseurl"`
Endpoint string `yaml:"endpoint"`
Token string `yaml:"token"`
}