fix: support chutes embedding response
This commit is contained in:
parent
5509234eff
commit
8596cb4b39
@ -90,6 +90,7 @@ func (s *StringSlice) UnmarshalYAML(value *yaml.Node) error {
|
||||
type ModelCfg struct {
|
||||
Provider string `yaml:"provider"`
|
||||
Models StringSlice `yaml:"models"`
|
||||
BaseURL string `yaml:"baseurl"`
|
||||
Endpoint string `yaml:"endpoint"`
|
||||
Token string `yaml:"token"`
|
||||
}
|
||||
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
@ -45,12 +46,22 @@ func (c *Chutes) Embed(ctx context.Context, inputs []string) ([][]float32, int,
|
||||
if resp.StatusCode >= 300 {
|
||||
return nil, 0, &HTTPError{Code: resp.StatusCode, Status: fmt.Sprintf("embed failed: %s", resp.Status)}
|
||||
}
|
||||
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var out struct {
|
||||
Data [][]float32 `json:"data"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
||||
return nil, 0, err
|
||||
|
||||
if err := json.Unmarshal(b, &out); err != nil || len(out.Data) == 0 {
|
||||
if err := json.Unmarshal(b, &out.Data); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(out.Data) != len(inputs) {
|
||||
return nil, 0, fmt.Errorf("embedding count mismatch")
|
||||
}
|
||||
|
||||
@ -8,18 +8,30 @@ import (
|
||||
)
|
||||
|
||||
func TestChutesEmbed(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"data":[[0.1,0.2],[0.3,0.4]]}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
emb := NewChutes(srv.URL, "", 0)
|
||||
vecs, _, err := emb.Embed(context.Background(), []string{"a", "b"})
|
||||
if err != nil {
|
||||
t.Fatalf("Embed returned error: %v", err)
|
||||
cases := []struct {
|
||||
name string
|
||||
response string
|
||||
}{
|
||||
{"object", `{"data":[[0.1,0.2],[0.3,0.4]]}`},
|
||||
{"array", `[[0.1,0.2],[0.3,0.4]]`},
|
||||
}
|
||||
if len(vecs) != 2 || len(vecs[0]) != 2 {
|
||||
t.Fatalf("unexpected embedding: %#v", vecs)
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(tc.response))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
emb := NewChutes(srv.URL, "", 0)
|
||||
vecs, _, err := emb.Embed(context.Background(), []string{"a", "b"})
|
||||
if err != nil {
|
||||
t.Fatalf("Embed returned error: %v", err)
|
||||
}
|
||||
if len(vecs) != 2 || len(vecs[0]) != 2 {
|
||||
t.Fatalf("unexpected embedding: %#v", vecs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -94,6 +94,7 @@ func (s *StringSlice) UnmarshalYAML(value *yaml.Node) error {
|
||||
type ModelCfg struct {
|
||||
Provider string `yaml:"provider"`
|
||||
Models StringSlice `yaml:"models"`
|
||||
BaseURL string `yaml:"baseurl"`
|
||||
Endpoint string `yaml:"endpoint"`
|
||||
Token string `yaml:"token"`
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user