diff --git a/internal/rag/config/config.go b/internal/rag/config/config.go
index 218627a..7f7d0ac 100644
--- a/internal/rag/config/config.go
+++ b/internal/rag/config/config.go
@@ -90,6 +90,7 @@ func (s *StringSlice) UnmarshalYAML(value *yaml.Node) error {
 type ModelCfg struct {
 	Provider string      `yaml:"provider"`
 	Models   StringSlice `yaml:"models"`
+	BaseURL  string      `yaml:"baseurl"`
 	Endpoint string      `yaml:"endpoint"`
 	Token    string      `yaml:"token"`
 }
diff --git a/internal/rag/embed/chutes.go b/internal/rag/embed/chutes.go
index 0ec459d..0942d85 100644
--- a/internal/rag/embed/chutes.go
+++ b/internal/rag/embed/chutes.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"io"
 	"net/http"
 	"time"
 )
@@ -45,12 +46,26 @@ func (c *Chutes) Embed(ctx context.Context, inputs []string) ([][]float32, int,
 	if resp.StatusCode >= 300 {
 		return nil, 0, &HTTPError{Code: resp.StatusCode, Status: fmt.Sprintf("embed failed: %s", resp.Status)}
 	}
+
+	b, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, 0, err
+	}
+
 	var out struct {
 		Data [][]float32 `json:"data"`
 	}
-	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
-		return nil, 0, err
+
+	// Accept both response shapes: {"data": [[...]]} and a bare [[...]].
+	// Fall back to the bare form only when the object decode fails, so a
+	// well-formed object with an empty or missing "data" is not re-parsed
+	// (and rejected with an unrelated type error) as an array.
+	if err := json.Unmarshal(b, &out); err != nil {
+		if arrErr := json.Unmarshal(b, &out.Data); arrErr != nil {
+			return nil, 0, fmt.Errorf("decode embed response: %w (bare-array decode: %v)", err, arrErr)
+		}
 	}
+
 	if len(out.Data) != len(inputs) {
 		return nil, 0, fmt.Errorf("embedding count mismatch")
 	}
diff --git a/internal/rag/embed/chutes_test.go b/internal/rag/embed/chutes_test.go
index 61f46f6..0b390f8 100644
--- a/internal/rag/embed/chutes_test.go
+++ b/internal/rag/embed/chutes_test.go
@@ -8,18 +8,30 @@ import (
 )
 
 func TestChutesEmbed(t *testing.T) {
-	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		w.Header().Set("Content-Type", "application/json")
-		w.Write([]byte(`{"data":[[0.1,0.2],[0.3,0.4]]}`))
-	}))
-	defer srv.Close()
-
-	emb := NewChutes(srv.URL, "", 0)
-	vecs, _, err := emb.Embed(context.Background(), []string{"a", "b"})
-	if err != nil {
-		t.Fatalf("Embed returned error: %v", err)
+	cases := []struct {
+		name     string
+		response string
+	}{
+		{"object", `{"data":[[0.1,0.2],[0.3,0.4]]}`},
+		{"array", `[[0.1,0.2],[0.3,0.4]]`},
 	}
-	if len(vecs) != 2 || len(vecs[0]) != 2 {
-		t.Fatalf("unexpected embedding: %#v", vecs)
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				w.Header().Set("Content-Type", "application/json")
+				w.Write([]byte(tc.response))
+			}))
+			defer srv.Close()
+
+			emb := NewChutes(srv.URL, "", 0)
+			vecs, _, err := emb.Embed(context.Background(), []string{"a", "b"})
+			if err != nil {
+				t.Fatalf("Embed returned error: %v", err)
+			}
+			if len(vecs) != 2 || len(vecs[0]) != 2 {
+				t.Fatalf("unexpected embedding: %#v", vecs)
+			}
+		})
 	}
 }
diff --git a/server/config/config.go b/server/config/config.go
index e25e7b5..cba5e6a 100644
--- a/server/config/config.go
+++ b/server/config/config.go
@@ -94,6 +94,7 @@ func (s *StringSlice) UnmarshalYAML(value *yaml.Node) error {
 type ModelCfg struct {
 	Provider string      `yaml:"provider"`
 	Models   StringSlice `yaml:"models"`
+	BaseURL  string      `yaml:"baseurl"`
 	Endpoint string      `yaml:"endpoint"`
 	Token    string      `yaml:"token"`
 }