Revert "refactor: add native chutes embeddings client"
This commit is contained in:
parent
122a6128df
commit
b645b18f64
@ -68,7 +68,19 @@ var rootCmd = &cobra.Command{
|
||||
embCfg := cfg.ResolveEmbedding()
|
||||
chunkCfg := cfg.ResolveChunking()
|
||||
|
||||
embedder := embed.NewClient(embCfg.Provider, embCfg.BaseURL, embCfg.APIKey, embCfg.Model)
|
||||
var embedder embed.Embedder
|
||||
switch embCfg.Provider {
|
||||
case "ollama":
|
||||
embedder = embed.NewOllama(embCfg.Endpoint, embCfg.Model, embCfg.Dimension)
|
||||
case "chutes":
|
||||
embedder = embed.NewOpenAI(embCfg.Endpoint, embCfg.APIKey, "", embCfg.Dimension)
|
||||
default:
|
||||
if embCfg.Model != "" {
|
||||
embedder = embed.NewOpenAI(embCfg.Endpoint, embCfg.APIKey, embCfg.Model, embCfg.Dimension)
|
||||
} else {
|
||||
embedder = embed.NewBGE(embCfg.Endpoint, embCfg.APIKey, embCfg.Dimension)
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := os.Getenv("SERVER_URL")
|
||||
if baseURL == "" {
|
||||
@ -129,7 +141,7 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
func ingestFile(ctx context.Context, cfg *rconfig.Config, chunkCfg rconfig.ChunkingCfg, embedder embed.Client, baseURL, filePath string) error {
|
||||
func ingestFile(ctx context.Context, cfg *rconfig.Config, chunkCfg rconfig.ChunkingCfg, embedder embed.Embedder, baseURL, filePath string) error {
|
||||
var ds *rconfig.DataSource
|
||||
var workdir string
|
||||
for i := range cfg.Global.Datasources {
|
||||
@ -166,7 +178,7 @@ func ingestFile(ctx context.Context, cfg *rconfig.Config, chunkCfg rconfig.Chunk
|
||||
ContentSHA: ch.SHA256,
|
||||
}
|
||||
}
|
||||
vecs, err := embedder.Embed(ctx, texts)
|
||||
vecs, _, err := embedder.Embed(ctx, texts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("embed %s: %w", filePath, err)
|
||||
}
|
||||
|
||||
@ -84,30 +84,26 @@ For local debugging with HuggingFace and Ollama:
|
||||
models:
|
||||
embedder:
|
||||
models: "bge-m3"
|
||||
baseurl: "http://127.0.0.1:9000"
|
||||
endpoint: "http://127.0.0.1:9000/v1/embeddings"
|
||||
generator:
|
||||
models:
|
||||
- 'llama2:13b'
|
||||
baseurl: "http://127.0.0.1:11434"
|
||||
endpoint: "http://127.0.0.1:11434"
|
||||
```
|
||||
|
||||
For online services using Chutes, set the provider to `chutes` and point the
|
||||
embedder base URL directly at the model host. The client automatically posts to
|
||||
`/embed` with the correct JSON payload:
|
||||
For online services using Chutes:
|
||||
|
||||
```yaml
|
||||
models:
|
||||
embedder:
|
||||
provider: "chutes"
|
||||
models: "bge-m3"
|
||||
baseurl: "https://chutes-baai-bge-m3.chutes.ai"
|
||||
token: "cpk_xxxx"
|
||||
generator:
|
||||
provider: "chutes"
|
||||
models:
|
||||
- 'moonshotai/Kimi-K2-Instruct'
|
||||
baseurl: "https://llm.chutes.ai/v1"
|
||||
token: "cpk_xxxx"
|
||||
#models:
|
||||
# embedder:
|
||||
# models: "bge-m3"
|
||||
# endpoint: "https://chutes-baai-bge-m3.chutes.ai/embed"
|
||||
# token: "cpk_xxxx"
|
||||
# generator:
|
||||
# models:
|
||||
# - 'moonshotai/Kimi-K2-Instruct'
|
||||
# endpoint: "https://llm.chutes.ai/v1"
|
||||
# token: "cpk_xxxx"
|
||||
```
|
||||
|
||||
The `api.askai` section controls request behaviour:
|
||||
|
||||
@ -22,12 +22,12 @@ models:
|
||||
embedder:
|
||||
provider: "huggingface_hub"
|
||||
models: "bge-m3"
|
||||
baseurl: "http://127.0.0.1:9000"
|
||||
endpoint: "http://127.0.0.1:9000/v1/embeddings"
|
||||
generator:
|
||||
provider: "ollama"
|
||||
models:
|
||||
- 'gemma3:4b'
|
||||
baseurl: "http://127.0.0.1:11434"
|
||||
endpoint: "http://127.0.0.1:11434/v1/chat/completions"
|
||||
|
||||
embedding:
|
||||
max_batch: 64
|
||||
|
||||
@ -20,17 +20,17 @@ sync:
|
||||
|
||||
provider:
|
||||
- name: ollama
|
||||
baseurl: http://localhost:11434
|
||||
endpoint: http://localhost:11434
|
||||
models:
|
||||
- 'gpt-oss:20b'
|
||||
- name: chutes
|
||||
baseurl: https://llm.chutes.ai/v1
|
||||
endpoint: https://llm.chutes.ai/v1
|
||||
token: "cpk_xxxxxxxxxxxxxxxxxx"
|
||||
models:
|
||||
- 'moonshotai/Kimi-K2-Instruct'
|
||||
|
||||
embedding:
|
||||
baseurl: https://chutes-baai-bge-m3.chutes.ai
|
||||
endpoint: https://chutes-baai-bge-m3.chutes.ai/embed/v1/embeddings
|
||||
token: "cpk_xxxxxxxxxxxxxxxxxx"
|
||||
dimension: 0 # 0 = auto-detect the dimension from the first response
|
||||
rate_limit_tpm: 120000
|
||||
|
||||
@ -90,7 +90,7 @@ func (s *StringSlice) UnmarshalYAML(value *yaml.Node) error {
|
||||
type ModelCfg struct {
|
||||
Provider string `yaml:"provider"`
|
||||
Models StringSlice `yaml:"models"`
|
||||
BaseURL string `yaml:"baseurl"`
|
||||
Endpoint string `yaml:"endpoint"`
|
||||
Token string `yaml:"token"`
|
||||
}
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ import (
|
||||
// RuntimeEmbedding is the resolved embedding configuration used at runtime.
|
||||
type RuntimeEmbedding struct {
|
||||
Provider string
|
||||
BaseURL string
|
||||
Endpoint string
|
||||
APIKey string
|
||||
Model string
|
||||
Dimension int
|
||||
@ -26,7 +26,7 @@ func (c *Config) ResolveEmbedding() RuntimeEmbedding {
|
||||
if len(m.Models) > 0 {
|
||||
rt.Model = m.Models[0]
|
||||
}
|
||||
rt.BaseURL = strings.TrimRight(m.BaseURL, "/")
|
||||
rt.Endpoint = strings.TrimRight(m.Endpoint, "/")
|
||||
rt.APIKey = m.Token
|
||||
|
||||
e := c.Embedding
|
||||
@ -128,7 +128,7 @@ func (rt *Runtime) ToConfig() *Config {
|
||||
c.Global.Datasources = rt.Datasources
|
||||
c.Global.Proxy = rt.Proxy
|
||||
c.Models.Embedder.Provider = rt.Embedding.Provider
|
||||
c.Models.Embedder.BaseURL = rt.Embedding.BaseURL
|
||||
c.Models.Embedder.Endpoint = rt.Embedding.Endpoint
|
||||
c.Models.Embedder.Token = rt.Embedding.APIKey
|
||||
if rt.Embedding.Model != "" {
|
||||
c.Models.Embedder.Models = []string{rt.Embedding.Model}
|
||||
|
||||
@ -21,12 +21,12 @@ func TestVectorDB_DSN(t *testing.T) {
|
||||
func TestResolveEmbedding(t *testing.T) {
|
||||
cfg := &Config{}
|
||||
cfg.Models.Embedder.Provider = "p1"
|
||||
cfg.Models.Embedder.BaseURL = "https://api.example.com"
|
||||
cfg.Models.Embedder.Endpoint = "https://api.example.com"
|
||||
cfg.Models.Embedder.Token = "tok"
|
||||
cfg.Models.Embedder.Models = []string{"m"}
|
||||
e := cfg.ResolveEmbedding()
|
||||
if e.BaseURL != "https://api.example.com" {
|
||||
t.Fatalf("unexpected base url %q", e.BaseURL)
|
||||
if e.Endpoint != "https://api.example.com" {
|
||||
t.Fatalf("unexpected endpoint %q", e.Endpoint)
|
||||
}
|
||||
if e.APIKey != "tok" {
|
||||
t.Fatalf("unexpected api key %q", e.APIKey)
|
||||
@ -49,12 +49,12 @@ func TestResolveChunking(t *testing.T) {
|
||||
|
||||
func TestRuntimeToConfigEmbedding(t *testing.T) {
|
||||
rt := &Runtime{}
|
||||
rt.Embedding.BaseURL = "http://localhost:8080"
|
||||
rt.Embedding.Endpoint = "http://localhost:8080"
|
||||
rt.Embedding.APIKey = "tok"
|
||||
rt.Embedding.Dimension = 123
|
||||
cfg := rt.ToConfig()
|
||||
if cfg.Models.Embedder.BaseURL != "http://localhost:8080" {
|
||||
t.Fatalf("unexpected base url %q", cfg.Models.Embedder.BaseURL)
|
||||
if cfg.Models.Embedder.Endpoint != "http://localhost:8080" {
|
||||
t.Fatalf("unexpected base url %q", cfg.Models.Embedder.Endpoint)
|
||||
}
|
||||
if cfg.Models.Embedder.Token != "tok" {
|
||||
t.Fatalf("unexpected token %q", cfg.Models.Embedder.Token)
|
||||
|
||||
85
internal/rag/embed/bge.go
Normal file
85
internal/rag/embed/bge.go
Normal file
@ -0,0 +1,85 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BGE implements the Embedder interface for a BGE embedding service.
|
||||
type BGE struct {
|
||||
endpoint string
|
||||
token string
|
||||
dim int
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewBGE returns a new BGE embedder.
|
||||
func NewBGE(endpoint, token string, dim int) *BGE {
|
||||
return &BGE{
|
||||
endpoint: endpoint,
|
||||
token: token,
|
||||
dim: dim,
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// Dimension returns the embedding dimension if known.
|
||||
func (b *BGE) Dimension() int { return b.dim }
|
||||
|
||||
// Embed posts texts to the BGE service and returns embeddings.
|
||||
func (b *BGE) Embed(ctx context.Context, inputs []string) ([][]float32, int, error) {
|
||||
vecs := make([][]float32, len(inputs))
|
||||
for i, text := range inputs {
|
||||
payload := map[string]any{"inputs": text}
|
||||
body, _ := json.Marshal(payload)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, b.endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if b.token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+b.token)
|
||||
}
|
||||
resp, err := b.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if resp.StatusCode >= 300 {
|
||||
resp.Body.Close()
|
||||
return nil, 0, &HTTPError{Code: resp.StatusCode, Status: fmt.Sprintf("embed failed: %s", resp.Status)}
|
||||
}
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var out struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &out); err != nil || len(out.Embedding) == 0 {
|
||||
// try raw array format
|
||||
var arr []float32
|
||||
if err := json.Unmarshal(data, &arr); err != nil || len(arr) == 0 {
|
||||
// some services return [[..]] even for single input
|
||||
var arr2 [][]float32
|
||||
if err := json.Unmarshal(data, &arr2); err != nil || len(arr2) == 0 {
|
||||
return nil, 0, err
|
||||
}
|
||||
arr = arr2[0]
|
||||
}
|
||||
out.Embedding = arr
|
||||
}
|
||||
|
||||
if b.dim == 0 {
|
||||
b.dim = len(out.Embedding)
|
||||
}
|
||||
vecs[i] = out.Embedding
|
||||
}
|
||||
return vecs, 0, nil
|
||||
}
|
||||
59
internal/rag/embed/bge_test.go
Normal file
59
internal/rag/embed/bge_test.go
Normal file
@ -0,0 +1,59 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBGEEmbedArray(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`[0.1,0.2]`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
emb := NewBGE(srv.URL, "", 0)
|
||||
vecs, _, err := emb.Embed(context.Background(), []string{"foo"})
|
||||
if err != nil {
|
||||
t.Fatalf("Embed returned error: %v", err)
|
||||
}
|
||||
if len(vecs) != 1 || len(vecs[0]) != 2 {
|
||||
t.Fatalf("unexpected embedding: %#v", vecs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBGEEmbedObject(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"embedding":[0.3,0.4]}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
emb := NewBGE(srv.URL, "", 0)
|
||||
vecs, _, err := emb.Embed(context.Background(), []string{"bar"})
|
||||
if err != nil {
|
||||
t.Fatalf("Embed returned error: %v", err)
|
||||
}
|
||||
if len(vecs) != 1 || len(vecs[0]) != 2 {
|
||||
t.Fatalf("unexpected embedding: %#v", vecs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBGEEmbedNestedArray(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`[[0.5,0.6]]`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
emb := NewBGE(srv.URL, "", 0)
|
||||
vecs, _, err := emb.Embed(context.Background(), []string{"baz"})
|
||||
if err != nil {
|
||||
t.Fatalf("Embed returned error: %v", err)
|
||||
}
|
||||
if len(vecs) != 1 || len(vecs[0]) != 2 {
|
||||
t.Fatalf("unexpected embedding: %#v", vecs)
|
||||
}
|
||||
}
|
||||
@ -1,93 +0,0 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ChutesClient implements embeddings for the Chutes API.
|
||||
type ChutesClient struct {
|
||||
endpoint string
|
||||
token string
|
||||
model string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewChutesClient returns a new Chutes client.
|
||||
func NewChutesClient(endpoint, token, model string) *ChutesClient {
|
||||
return &ChutesClient{
|
||||
endpoint: endpoint,
|
||||
token: token,
|
||||
model: model,
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// Embed posts the inputs to the Chutes /embed endpoint.
|
||||
func (c *ChutesClient) Embed(ctx context.Context, inputs []string) ([][]float32, error) {
|
||||
payload := map[string]any{"inputs": inputs}
|
||||
if c.model != "" {
|
||||
payload["model"] = c.model
|
||||
}
|
||||
b, _ := json.Marshal(payload)
|
||||
backoff := time.Second
|
||||
for attempt := 0; attempt < 3; attempt++ {
|
||||
body := bytes.NewReader(b)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if c.token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+c.token)
|
||||
}
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
if attempt == 2 {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if resp.StatusCode == 429 || resp.StatusCode >= 500 {
|
||||
resp.Body.Close()
|
||||
} else if resp.StatusCode >= 400 {
|
||||
resp.Body.Close()
|
||||
return nil, &HTTPError{Code: resp.StatusCode, Status: fmt.Sprintf("embed failed: %s", resp.Status)}
|
||||
} else {
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Try Chutes format
|
||||
var chutes struct {
|
||||
Data [][]float32 `json:"data"`
|
||||
}
|
||||
if json.Unmarshal(data, &chutes) == nil && len(chutes.Data) > 0 {
|
||||
return chutes.Data, nil
|
||||
}
|
||||
// Try OpenAI format
|
||||
var openai struct {
|
||||
Data []struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if json.Unmarshal(data, &openai) == nil && len(openai.Data) > 0 {
|
||||
vecs := make([][]float32, len(openai.Data))
|
||||
for i, d := range openai.Data {
|
||||
vecs[i] = d.Embedding
|
||||
}
|
||||
return vecs, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected embed response")
|
||||
}
|
||||
}
|
||||
time.Sleep(backoff)
|
||||
backoff *= 2
|
||||
}
|
||||
return nil, fmt.Errorf("embed failed after retries")
|
||||
}
|
||||
@ -2,8 +2,8 @@ package embed
|
||||
|
||||
import "context"
|
||||
|
||||
// Client defines embedding operations for various providers.
|
||||
type Client interface {
|
||||
// Embed converts input texts into embedding vectors.
|
||||
Embed(ctx context.Context, inputs []string) ([][]float32, error)
|
||||
// Embedder defines embedding operations.
|
||||
type Embedder interface {
|
||||
Embed(ctx context.Context, inputs []string) ([][]float32, int, error)
|
||||
Dimension() int
|
||||
}
|
||||
|
||||
@ -1,32 +0,0 @@
|
||||
package embed
|
||||
|
||||
import "strings"
|
||||
|
||||
// NewClient creates an embeddings client based on provider or base URL.
|
||||
func NewClient(provider, baseURL, token, model string) Client {
|
||||
p := strings.ToLower(provider)
|
||||
if p == "" {
|
||||
if strings.Contains(strings.ToLower(baseURL), "chutes") {
|
||||
p = "chutes"
|
||||
} else if strings.Contains(strings.ToLower(baseURL), "ollama") {
|
||||
p = "ollama"
|
||||
} else {
|
||||
p = "openai"
|
||||
}
|
||||
}
|
||||
endpoint := strings.TrimRight(baseURL, "/")
|
||||
switch p {
|
||||
case "chutes":
|
||||
if !strings.HasSuffix(endpoint, "/embed") {
|
||||
endpoint += "/embed"
|
||||
}
|
||||
return NewChutesClient(endpoint, token, model)
|
||||
case "ollama":
|
||||
return NewOllamaClient(endpoint, model)
|
||||
default:
|
||||
if !strings.HasSuffix(endpoint, "/v1/embeddings") {
|
||||
endpoint += "/v1/embeddings"
|
||||
}
|
||||
return NewOpenAIClient(endpoint, token, model)
|
||||
}
|
||||
}
|
||||
@ -11,24 +11,29 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// OllamaClient implements embeddings using the Ollama API.
|
||||
type OllamaClient struct {
|
||||
// Ollama implements the Embedder interface using the Ollama embeddings API.
|
||||
type Ollama struct {
|
||||
endpoint string
|
||||
model string
|
||||
dim int
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewOllamaClient returns a new Ollama client.
|
||||
func NewOllamaClient(endpoint, model string) *OllamaClient {
|
||||
return &OllamaClient{
|
||||
// NewOllama creates a new Ollama embedder.
|
||||
func NewOllama(endpoint, model string, dim int) *Ollama {
|
||||
return &Ollama{
|
||||
endpoint: strings.TrimRight(endpoint, "/"),
|
||||
model: model,
|
||||
dim: dim,
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// Dimension returns the embedding dimension if known.
|
||||
func (a *Ollama) Dimension() int { return a.dim }
|
||||
|
||||
// Embed posts texts to the Ollama embeddings endpoint.
|
||||
func (a *OllamaClient) Embed(ctx context.Context, inputs []string) ([][]float32, error) {
|
||||
func (a *Ollama) Embed(ctx context.Context, inputs []string) ([][]float32, int, error) {
|
||||
vecs := make([][]float32, len(inputs))
|
||||
url := a.endpoint
|
||||
for i, text := range inputs {
|
||||
@ -36,27 +41,30 @@ func (a *OllamaClient) Embed(ctx context.Context, inputs []string) ([][]float32,
|
||||
body, _ := json.Marshal(payload)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := a.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
if resp.StatusCode >= 300 {
|
||||
resp.Body.Close()
|
||||
return nil, &HTTPError{Code: resp.StatusCode, Status: fmt.Sprintf("embed failed: %s", resp.Status)}
|
||||
return nil, 0, &HTTPError{Code: resp.StatusCode, Status: fmt.Sprintf("embed failed: %s", resp.Status)}
|
||||
}
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
var out struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
if a.dim == 0 {
|
||||
a.dim = len(out.Embedding)
|
||||
}
|
||||
vec := make([]float32, len(out.Embedding))
|
||||
for j, v := range out.Embedding {
|
||||
@ -64,5 +72,5 @@ func (a *OllamaClient) Embed(ctx context.Context, inputs []string) ([][]float32,
|
||||
}
|
||||
vecs[i] = vec
|
||||
}
|
||||
return vecs, nil
|
||||
return vecs, 0, nil
|
||||
}
|
||||
|
||||
@ -4,90 +4,78 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OpenAIClient implements embeddings against OpenAI-compatible services.
|
||||
type OpenAIClient struct {
|
||||
// OpenAI implements the Embedder interface using OpenAI-compatible APIs.
|
||||
type OpenAI struct {
|
||||
endpoint string
|
||||
apiKey string
|
||||
model string
|
||||
dim int
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewOpenAIClient returns a new OpenAI client.
|
||||
func NewOpenAIClient(endpoint, apiKey, model string) *OpenAIClient {
|
||||
return &OpenAIClient{
|
||||
// NewOpenAI creates a new OpenAI embedder from configuration.
|
||||
func NewOpenAI(endpoint, apiKey, model string, dim int) *OpenAI {
|
||||
return &OpenAI{
|
||||
endpoint: endpoint,
|
||||
apiKey: apiKey,
|
||||
model: model,
|
||||
dim: dim,
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// Embed posts the inputs to the OpenAI embeddings endpoint.
|
||||
func (c *OpenAIClient) Embed(ctx context.Context, inputs []string) ([][]float32, error) {
|
||||
// Dimension returns the embedding dimension if known.
|
||||
func (o *OpenAI) Dimension() int { return o.dim }
|
||||
|
||||
// Embed embeds the inputs and returns vectors and token usage.
|
||||
func (o *OpenAI) Embed(ctx context.Context, inputs []string) ([][]float32, int, error) {
|
||||
payload := map[string]any{"input": inputs}
|
||||
if c.model != "" {
|
||||
payload["model"] = c.model
|
||||
if o.model != "" {
|
||||
payload["model"] = o.model
|
||||
}
|
||||
b, _ := json.Marshal(payload)
|
||||
backoff := time.Second
|
||||
for attempt := 0; attempt < 3; attempt++ {
|
||||
body := bytes.NewReader(b)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if c.apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||
}
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
if attempt == 2 {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if resp.StatusCode == 429 || resp.StatusCode >= 500 {
|
||||
resp.Body.Close()
|
||||
} else if resp.StatusCode >= 400 {
|
||||
resp.Body.Close()
|
||||
return nil, &HTTPError{Code: resp.StatusCode, Status: fmt.Sprintf("embed failed: %s", resp.Status)}
|
||||
} else {
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Try OpenAI format
|
||||
var openai struct {
|
||||
Data []struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if json.Unmarshal(data, &openai) == nil && len(openai.Data) > 0 {
|
||||
vecs := make([][]float32, len(openai.Data))
|
||||
for i, d := range openai.Data {
|
||||
vecs[i] = d.Embedding
|
||||
}
|
||||
return vecs, nil
|
||||
}
|
||||
// Try Chutes format
|
||||
var chutes struct {
|
||||
Data [][]float32 `json:"data"`
|
||||
}
|
||||
if json.Unmarshal(data, &chutes) == nil && len(chutes.Data) > 0 {
|
||||
return chutes.Data, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected embed response")
|
||||
}
|
||||
}
|
||||
time.Sleep(backoff)
|
||||
backoff *= 2
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, o.endpoint, bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return nil, fmt.Errorf("embed failed after retries")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if o.apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+o.apiKey)
|
||||
}
|
||||
resp, err := o.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 300 {
|
||||
return nil, 0, &HTTPError{Code: resp.StatusCode, Status: fmt.Sprintf("embed failed: %s", resp.Status)}
|
||||
}
|
||||
var out struct {
|
||||
Data []struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
} `json:"data"`
|
||||
Usage struct {
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if len(out.Data) != len(inputs) {
|
||||
return nil, 0, errors.New("embedding count mismatch")
|
||||
}
|
||||
if o.dim == 0 && len(out.Data) > 0 {
|
||||
o.dim = len(out.Data[0].Embedding)
|
||||
}
|
||||
vecs := make([][]float32, len(out.Data))
|
||||
for i, d := range out.Data {
|
||||
vecs[i] = d.Embedding
|
||||
}
|
||||
return vecs, out.Usage.TotalTokens, nil
|
||||
}
|
||||
|
||||
@ -67,8 +67,20 @@ func IngestRepo(ctx context.Context, cfg *cfgpkg.Config, ds cfgpkg.DataSource, o
|
||||
}
|
||||
defer conn.Close(ctx)
|
||||
|
||||
embedder := embed.NewClient(embCfg.Provider, embCfg.BaseURL, embCfg.APIKey, embCfg.Model)
|
||||
if err := store.EnsureSchema(ctx, conn, embCfg.Dimension, opt.MigrateDim); err != nil {
|
||||
var embedder embed.Embedder
|
||||
switch embCfg.Provider {
|
||||
case "ollama":
|
||||
embedder = embed.NewOllama(embCfg.Endpoint, embCfg.Model, embCfg.Dimension)
|
||||
case "chutes":
|
||||
embedder = embed.NewOpenAI(embCfg.Endpoint, embCfg.APIKey, "", embCfg.Dimension)
|
||||
default:
|
||||
if embCfg.Model != "" {
|
||||
embedder = embed.NewOpenAI(embCfg.Endpoint, embCfg.APIKey, embCfg.Model, embCfg.Dimension)
|
||||
} else {
|
||||
embedder = embed.NewBGE(embCfg.Endpoint, embCfg.APIKey, embCfg.Dimension)
|
||||
}
|
||||
}
|
||||
if err := store.EnsureSchema(ctx, conn, embedder.Dimension(), opt.MigrateDim); err != nil {
|
||||
st.Errors = append(st.Errors, err)
|
||||
return st, err
|
||||
}
|
||||
@ -101,12 +113,13 @@ func IngestRepo(ctx context.Context, cfg *cfgpkg.Config, ds cfgpkg.DataSource, o
|
||||
ContentSHA: ch.SHA256,
|
||||
}
|
||||
}
|
||||
vecs, err := embedder.Embed(ctx, texts)
|
||||
vecs, tokens, err := embedder.Embed(ctx, texts)
|
||||
if err != nil {
|
||||
st.Errors = append(st.Errors, err)
|
||||
continue
|
||||
}
|
||||
st.EmbeddingsCreated += len(vecs)
|
||||
st.TokensEstimated += tokens
|
||||
for i := range rows {
|
||||
rows[i].Embedding = vecs[i]
|
||||
}
|
||||
|
||||
@ -59,11 +59,23 @@ func (s *Service) Query(ctx context.Context, question string, limit int) ([]Docu
|
||||
return nil, nil
|
||||
}
|
||||
embCfg := s.cfg.ResolveEmbedding()
|
||||
if embCfg.BaseURL == "" {
|
||||
if embCfg.Endpoint == "" {
|
||||
return nil, nil
|
||||
}
|
||||
emb := embed.NewClient(embCfg.Provider, embCfg.BaseURL, embCfg.APIKey, embCfg.Model)
|
||||
vecs, err := emb.Embed(ctx, []string{question})
|
||||
var emb embed.Embedder
|
||||
switch embCfg.Provider {
|
||||
case "ollama":
|
||||
emb = embed.NewOllama(embCfg.Endpoint, embCfg.Model, embCfg.Dimension)
|
||||
case "chutes":
|
||||
emb = embed.NewOpenAI(embCfg.Endpoint, embCfg.APIKey, "", embCfg.Dimension)
|
||||
default:
|
||||
if embCfg.Model != "" {
|
||||
emb = embed.NewOpenAI(embCfg.Endpoint, embCfg.APIKey, embCfg.Model, embCfg.Dimension)
|
||||
} else {
|
||||
emb = embed.NewBGE(embCfg.Endpoint, embCfg.APIKey, embCfg.Dimension)
|
||||
}
|
||||
}
|
||||
vecs, _, err := emb.Embed(ctx, []string{question})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -159,8 +171,8 @@ func (s *Service) Query(ctx context.Context, question string, limit int) ([]Docu
|
||||
// optional reranking
|
||||
var rr rerank.Reranker
|
||||
rCfg := s.cfg.Models.Reranker
|
||||
if rCfg.BaseURL != "" {
|
||||
rr = rerank.NewBGE(rCfg.BaseURL, rCfg.Token)
|
||||
if rCfg.Endpoint != "" {
|
||||
rr = rerank.NewBGE(rCfg.Endpoint, rCfg.Token)
|
||||
}
|
||||
if rr != nil {
|
||||
docs := make([]string, len(candidates))
|
||||
|
||||
@ -68,9 +68,9 @@ var ConfigPath = filepath.Join("server", "config", "server.yaml")
|
||||
type serverConfig struct {
|
||||
Models struct {
|
||||
Generator struct {
|
||||
Models []string `yaml:"models"`
|
||||
BaseURL string `yaml:"baseurl"`
|
||||
Token string `yaml:"token"`
|
||||
Models []string `yaml:"models"`
|
||||
Endpoint string `yaml:"endpoint"`
|
||||
Token string `yaml:"token"`
|
||||
} `yaml:"generator"`
|
||||
} `yaml:"models"`
|
||||
API struct {
|
||||
@ -81,11 +81,11 @@ type serverConfig struct {
|
||||
} `yaml:"api"`
|
||||
}
|
||||
|
||||
// loadConfig reads model, base URL, timeout and retries from ConfigPath
|
||||
// loadConfig reads model, endpoint, timeout and retries from ConfigPath
|
||||
// and environment variables.
|
||||
func loadConfig() (string, string, string, time.Duration, int) {
|
||||
model := os.Getenv("CHUTES_API_MODEL")
|
||||
baseURL := os.Getenv("CHUTES_API_URL")
|
||||
endpoint := os.Getenv("CHUTES_API_URL")
|
||||
token := ""
|
||||
timeout := 30 * time.Second
|
||||
retries := 3
|
||||
@ -97,8 +97,8 @@ func loadConfig() (string, string, string, time.Duration, int) {
|
||||
if model == "" && len(g.Models) > 0 {
|
||||
model = g.Models[0]
|
||||
}
|
||||
if baseURL == "" {
|
||||
baseURL = g.BaseURL
|
||||
if endpoint == "" {
|
||||
endpoint = g.Endpoint
|
||||
}
|
||||
if token == "" {
|
||||
token = g.Token
|
||||
@ -115,15 +115,7 @@ func loadConfig() (string, string, string, time.Duration, int) {
|
||||
if retries > 3 {
|
||||
retries = 3
|
||||
}
|
||||
baseURL = strings.TrimRight(baseURL, "/")
|
||||
endpoint := baseURL
|
||||
if !strings.HasSuffix(endpoint, "/chat/completions") {
|
||||
if strings.HasSuffix(endpoint, "/v1") {
|
||||
endpoint += "/chat/completions"
|
||||
} else {
|
||||
endpoint += "/v1/chat/completions"
|
||||
}
|
||||
}
|
||||
endpoint = strings.TrimRight(endpoint, "/")
|
||||
if model == "" {
|
||||
if token == "" || strings.Contains(endpoint, "127.0.0.1") || strings.Contains(endpoint, "localhost") {
|
||||
model = "llama2:13b"
|
||||
|
||||
@ -11,7 +11,7 @@ import (
|
||||
func TestLoadConfig_FromFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "server.yaml")
|
||||
data := []byte("models:\n generator:\n models: [\"llama2:13b\"]\n baseurl: http://localhost:11434\n token: t1\napi:\n askai:\n timeout: 10\n retries: 2\n")
|
||||
data := []byte("models:\n generator:\n models: [\"llama2:13b\"]\n endpoint: http://localhost:11434/v1/chat/completions\n token: t1\napi:\n askai:\n timeout: 10\n retries: 2\n")
|
||||
if err := os.WriteFile(cfgPath, data, 0o644); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
@ -42,7 +42,7 @@ func TestLoadConfig_FromFile(t *testing.T) {
|
||||
func TestLoadConfig_EnvOverrides(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "server.yaml")
|
||||
data := []byte("models:\n generator:\n models: [\"llama2:13b\"]\n baseurl: http://localhost:11434\napi:\n askai:\n timeout: 50\n retries: 5\n")
|
||||
data := []byte("models:\n generator:\n models: [\"llama2:13b\"]\n endpoint: http://localhost:11434/v1/chat/completions\napi:\n askai:\n timeout: 50\n retries: 5\n")
|
||||
if err := os.WriteFile(cfgPath, data, 0o644); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
@ -94,7 +94,7 @@ func (s *StringSlice) UnmarshalYAML(value *yaml.Node) error {
|
||||
type ModelCfg struct {
|
||||
Provider string `yaml:"provider"`
|
||||
Models StringSlice `yaml:"models"`
|
||||
BaseURL string `yaml:"baseurl"`
|
||||
Endpoint string `yaml:"endpoint"`
|
||||
Token string `yaml:"token"`
|
||||
}
|
||||
|
||||
|
||||
@ -1,51 +0,0 @@
|
||||
global:
|
||||
redis:
|
||||
addr: "127.0.0.1:6379"
|
||||
password: ""
|
||||
vectordb:
|
||||
pgurl: "postgres://shenlan:password@127.0.0.1:5432/shenlan"
|
||||
datasources:
|
||||
- name: Xstream
|
||||
repo: https://github.com/svc-design/Xstream
|
||||
path: docs
|
||||
- name: XControl
|
||||
repo: https://github.com/svc-design/XControl
|
||||
path: docs
|
||||
- name: documents
|
||||
repo: https://github.com/svc-design/documents
|
||||
path: /
|
||||
sync:
|
||||
repo:
|
||||
proxy: socks5://127.0.0.1:1080
|
||||
|
||||
models:
|
||||
embedder:
|
||||
provider: "chutes"
|
||||
models: "bge-m3"
|
||||
baseurl: "https://chutes-baai-bge-m3.chutes.ai"
|
||||
token: "cpk_xxxx"
|
||||
generator:
|
||||
provider: "chutes"
|
||||
models:
|
||||
- 'deepseek-ai/DeepSeek-R1'
|
||||
baseurl: "https://llm.chutes.ai/v1"
|
||||
token: "cpk_xxxx"
|
||||
|
||||
embedding:
|
||||
max_batch: 64
|
||||
dimension: 1024 # embedding vector dimension
|
||||
max_chars: 8000
|
||||
rate_limit_tpm: 120000
|
||||
|
||||
chunking:
|
||||
embed_toc: true
|
||||
max_tokens: 800
|
||||
overlap_tokens: 80
|
||||
prefer_heading_split: true
|
||||
include_exts: [".md", ".mdx"]
|
||||
ignore_dirs: [".git", "node_modules", "dist", "build"]
|
||||
|
||||
api:
|
||||
askai:
|
||||
timeout: 100
|
||||
retries: 3
|
||||
@ -21,11 +21,11 @@ sync:
|
||||
models:
|
||||
embedder:
|
||||
models: "bge-m3"
|
||||
baseurl: "http://127.0.0.1:9000"
|
||||
endpoint: "http://127.0.0.1:9000/v1/embeddings"
|
||||
generator:
|
||||
models:
|
||||
- 'deepseek-r1:8b'
|
||||
baseurl: "http://127.0.0.1:11434"
|
||||
endpoint: "http://127.0.0.1:11434/v1/chat/completions"
|
||||
|
||||
embedding:
|
||||
max_batch: 64
|
||||
|
||||
@ -23,22 +23,22 @@ sync:
|
||||
models:
|
||||
embedder:
|
||||
models: "bge-m3"
|
||||
baseurl: "http://127.0.0.1:9000"
|
||||
endpoint: "http://127.0.0.1:9000/v1/embeddings"
|
||||
generator:
|
||||
models:
|
||||
- 'llama2:13b'
|
||||
baseurl: "http://127.0.0.1:11434"
|
||||
endpoint: "http://127.0.0.1:11434/v1/chat/completions"
|
||||
token: ""
|
||||
# For PROD
|
||||
#models:
|
||||
# embedder:
|
||||
#models: "bge-m3"
|
||||
#baseurl: "https://chutes-baai-bge-m3.chutes.ai"
|
||||
#endpoint: "https://chutes-baai-bge-m3.chutes.ai/embed"
|
||||
#token: "cpk_xxxx"
|
||||
# generator:
|
||||
#models:
|
||||
# - 'moonshotai/Kimi-K2-Instruct'
|
||||
#baseurl: "https://llm.chutes.ai/v1"
|
||||
#endpoint: "https://llm.chutes.ai/v1"
|
||||
#token: "cpk_xxxx"
|
||||
|
||||
embedding:
|
||||
|
||||
@ -17,7 +17,7 @@ type Config struct {
|
||||
Generator struct {
|
||||
Provider string `yaml:"provider"`
|
||||
Models []string `yaml:"models"`
|
||||
BaseURL string `yaml:"baseurl"`
|
||||
Endpoint string `yaml:"endpoint"`
|
||||
Token string `yaml:"token"`
|
||||
} `yaml:"generator"`
|
||||
} `yaml:"models"`
|
||||
@ -43,8 +43,8 @@ func loadConfig() {
|
||||
if g.Token != "" {
|
||||
os.Setenv("CHUTES_API_TOKEN", g.Token)
|
||||
}
|
||||
if g.BaseURL != "" {
|
||||
os.Setenv("CHUTES_API_URL", g.BaseURL)
|
||||
if g.Endpoint != "" {
|
||||
os.Setenv("CHUTES_API_URL", g.Endpoint)
|
||||
}
|
||||
if len(g.Models) > 0 {
|
||||
os.Setenv("CHUTES_API_MODEL", g.Models[0])
|
||||
|
||||
Loading…
Reference in New Issue
Block a user