feat: add bge embedder
This commit is contained in:
parent
69584ad4e3
commit
58ed219c95
@ -72,7 +72,12 @@ func main() {
|
||||
|
||||
embCfg := cfg.ResolveEmbedding()
|
||||
chunkCfg := cfg.ResolveChunking()
|
||||
embedder := embed.NewOpenAI(embCfg.BaseURL, embCfg.APIKey, embCfg.Model, embCfg.Dimension)
|
||||
var embedder embed.Embedder
|
||||
if embCfg.Model != "" {
|
||||
embedder = embed.NewOpenAI(embCfg.BaseURL, embCfg.APIKey, embCfg.Model, embCfg.Dimension)
|
||||
} else {
|
||||
embedder = embed.NewBGE(embCfg.BaseURL, embCfg.APIKey, embCfg.Dimension)
|
||||
}
|
||||
var syncErrs []string
|
||||
|
||||
for _, ds := range cfg.Global.Datasources {
|
||||
|
||||
@ -28,8 +28,8 @@ provider:
|
||||
- 'moonshotai/Kimi-K2-Instruct'
|
||||
|
||||
embedding:
|
||||
provider: chutes # openai|azure|custom
|
||||
model: "BAAI/bge-m3"
|
||||
base_url: https://chutes-baai-bge-m3.chutes.ai/embed
|
||||
token: "cpk_xxxxxxxxxxxxxxxxxx"
|
||||
dimension: 0 # 0 = 首次响应自动探测维度
|
||||
rate_limit_tpm: 120000
|
||||
max_batch: 64
|
||||
|
||||
@ -68,6 +68,7 @@ type Provider struct {
|
||||
type EmbeddingCfg struct {
|
||||
Provider string `yaml:"provider"`
|
||||
BaseURL string `yaml:"base_url"`
|
||||
Token string `yaml:"token"`
|
||||
Model string `yaml:"model"`
|
||||
APIKeyEnv string `yaml:"api_key_env"`
|
||||
Dimension int `yaml:"dimension"`
|
||||
|
||||
@ -46,6 +46,8 @@ func (c *Config) ResolveEmbedding() RuntimeEmbedding {
|
||||
|
||||
if e.APIKeyEnv != "" {
|
||||
rt.APIKey = os.Getenv(e.APIKeyEnv)
|
||||
} else if e.Token != "" {
|
||||
rt.APIKey = e.Token
|
||||
} else if prov != nil {
|
||||
rt.APIKey = prov.Token
|
||||
}
|
||||
|
||||
@ -4,39 +4,66 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BGE calls a local bge-m3 embedding service.
|
||||
// BGE implements the Embedder interface for a BGE embedding service.
|
||||
type BGE struct {
|
||||
Endpoint string
|
||||
Client *http.Client
|
||||
baseURL string
|
||||
token string
|
||||
dim int
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewBGE returns a new BGE embedder.
|
||||
func NewBGE(endpoint string) *BGE {
|
||||
return &BGE{Endpoint: endpoint, Client: &http.Client{}}
|
||||
func NewBGE(baseURL, token string, dim int) *BGE {
|
||||
return &BGE{
|
||||
baseURL: baseURL,
|
||||
token: token,
|
||||
dim: dim,
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// Embed posts text to the bge service and parses the vector.
|
||||
func (b *BGE) Embed(ctx context.Context, text string) ([]float32, error) {
|
||||
body := map[string]string{"text": text}
|
||||
data, _ := json.Marshal(body)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", b.Endpoint, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// Dimension returns the embedding dimension if known.
|
||||
func (b *BGE) Dimension() int { return b.dim }
|
||||
|
||||
// Embed posts texts to the BGE service and returns embeddings.
|
||||
func (b *BGE) Embed(ctx context.Context, inputs []string) ([][]float32, int, error) {
|
||||
vecs := make([][]float32, len(inputs))
|
||||
for i, text := range inputs {
|
||||
payload := map[string]any{"inputs": text}
|
||||
body, _ := json.Marshal(payload)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, b.baseURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if b.token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+b.token)
|
||||
}
|
||||
resp, err := b.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if resp.StatusCode >= 300 {
|
||||
resp.Body.Close()
|
||||
return nil, 0, fmt.Errorf("embed failed: %s", resp.Status)
|
||||
}
|
||||
var out struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
||||
resp.Body.Close()
|
||||
return nil, 0, err
|
||||
}
|
||||
resp.Body.Close()
|
||||
if b.dim == 0 {
|
||||
b.dim = len(out.Embedding)
|
||||
}
|
||||
vecs[i] = out.Embedding
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := b.Client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
var res struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&res); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.Embedding, nil
|
||||
return vecs, 0, nil
|
||||
}
|
||||
|
||||
@ -63,7 +63,12 @@ func IngestRepo(ctx context.Context, cfg *cfgpkg.Config, ds cfgpkg.DataSource, o
|
||||
}
|
||||
defer conn.Close(ctx)
|
||||
|
||||
embedder := embed.NewOpenAI(embCfg.BaseURL, embCfg.APIKey, embCfg.Model, embCfg.Dimension)
|
||||
var embedder embed.Embedder
|
||||
if embCfg.Model != "" {
|
||||
embedder = embed.NewOpenAI(embCfg.BaseURL, embCfg.APIKey, embCfg.Model, embCfg.Dimension)
|
||||
} else {
|
||||
embedder = embed.NewBGE(embCfg.BaseURL, embCfg.APIKey, embCfg.Dimension)
|
||||
}
|
||||
if err := store.EnsureSchema(ctx, conn, embedder.Dimension(), opt.MigrateDim); err != nil {
|
||||
st.Errors = append(st.Errors, err)
|
||||
return st, err
|
||||
|
||||
@ -55,10 +55,15 @@ func (s *Service) Query(ctx context.Context, question string, limit int) ([]Docu
|
||||
return nil, nil
|
||||
}
|
||||
embCfg := s.cfg.ResolveEmbedding()
|
||||
if embCfg.APIKey == "" || embCfg.BaseURL == "" || embCfg.Model == "" {
|
||||
if embCfg.APIKey == "" || embCfg.BaseURL == "" {
|
||||
return nil, nil
|
||||
}
|
||||
emb := embed.NewOpenAI(embCfg.BaseURL, embCfg.APIKey, embCfg.Model, embCfg.Dimension)
|
||||
var emb embed.Embedder
|
||||
if embCfg.Model != "" {
|
||||
emb = embed.NewOpenAI(embCfg.BaseURL, embCfg.APIKey, embCfg.Model, embCfg.Dimension)
|
||||
} else {
|
||||
emb = embed.NewBGE(embCfg.BaseURL, embCfg.APIKey, embCfg.Dimension)
|
||||
}
|
||||
vecs, _, err := emb.Embed(ctx, []string{question})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
Loading…
Reference in New Issue
Block a user