feat: add allama support

This commit is contained in:
shenlan 2025-08-10 12:38:15 +08:00
parent 47a8ff25e4
commit fe9b4f7a4d
9 changed files with 231 additions and 37 deletions

View File

@ -69,10 +69,15 @@ var rootCmd = &cobra.Command{
chunkCfg := cfg.ResolveChunking()
var embedder embed.Embedder
if embCfg.Model != "" {
embedder = embed.NewOpenAI(embCfg.BaseURL, embCfg.APIKey, embCfg.Model, embCfg.Dimension)
} else {
embedder = embed.NewBGE(embCfg.BaseURL, embCfg.APIKey, embCfg.Dimension)
switch embCfg.Provider {
case "allama":
embedder = embed.NewAllama(embCfg.BaseURL, embCfg.Model, embCfg.Dimension)
default:
if embCfg.Model != "" {
embedder = embed.NewOpenAI(embCfg.BaseURL, embCfg.APIKey, embCfg.Model, embCfg.Dimension)
} else {
embedder = embed.NewBGE(embCfg.BaseURL, embCfg.APIKey, embCfg.Dimension)
}
}
baseURL := os.Getenv("SERVER_URL")

View File

@ -10,6 +10,7 @@ import (
// RuntimeEmbedding is the resolved embedding configuration used at runtime.
type RuntimeEmbedding struct {
Provider string
BaseURL string
APIKey string
Model string
@ -23,6 +24,7 @@ type RuntimeEmbedding struct {
func (c *Config) ResolveEmbedding() RuntimeEmbedding {
e := c.Embedding
var rt RuntimeEmbedding
rt.Provider = e.Provider
rt.Model = e.Model
rt.Dimension = e.Dimension
rt.RateLimitTPM = e.RateLimitTPM
@ -83,8 +85,10 @@ type Runtime struct {
Datasources []DataSource `yaml:"datasources"`
Proxy string `yaml:"proxy"`
Embedding struct {
Provider string `yaml:"provider"`
BaseURL string `yaml:"base_url"`
Token string `yaml:"token"`
Model string `yaml:"model"`
Dimension int `yaml:"dimension"`
} `yaml:"embedding"`
}
@ -117,8 +121,10 @@ func (rt *Runtime) ToConfig() *Config {
c.Global.VectorDB = rt.VectorDB
c.Global.Datasources = rt.Datasources
c.Global.Proxy = rt.Proxy
c.Embedding.Provider = rt.Embedding.Provider
c.Embedding.BaseURL = rt.Embedding.BaseURL
c.Embedding.Token = rt.Embedding.Token
c.Embedding.Model = rt.Embedding.Model
c.Embedding.Dimension = rt.Embedding.Dimension
return &c
}

View File

@ -0,0 +1,76 @@
package embed
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// Allama implements the Embedder interface using the Allama/Ollama embeddings API.
type Allama struct {
	baseURL string       // server root; NewAllama strips trailing slashes
	model   string       // model name sent with every request
	dim     int          // embedding dimension; 0 means not known yet
	client  *http.Client // HTTP client used for every request
}

// NewAllama creates a new Allama embedder.
//
// baseURL is the server root (trailing slashes are removed), model is the
// model name sent with each request, and dim is the expected embedding
// dimension or 0 when it should be learned from the first response. The
// embedder's HTTP client uses a 30 second timeout.
func NewAllama(baseURL, model string, dim int) *Allama {
	return &Allama{
		baseURL: strings.TrimRight(baseURL, "/"),
		model:   model,
		dim:     dim,
		client:  &http.Client{Timeout: 30 * time.Second},
	}
}

// Dimension returns the embedding dimension if known.
//
// It is the value passed to NewAllama or, when that was 0, the length of the
// first embedding received by Embed. It is 0 until one of those is available.
func (a *Allama) Dimension() int { return a.dim }

// Embed posts texts to the Allama embeddings endpoint.
//
// One POST to <baseURL>/api/embeddings is issued per input, sequentially and
// in order. The returned slice holds one vector per input at the matching
// index. The second return value is always 0. On the first failure Embed
// stops and returns a nil slice together with an error that names the index
// of the failing input.
//
// NOTE: when the dimension was not configured, Embed records it in the
// receiver without synchronization.
func (a *Allama) Embed(ctx context.Context, inputs []string) ([][]float32, int, error) {
	vecs := make([][]float32, len(inputs))
	url := a.baseURL + "/api/embeddings"
	for i, text := range inputs {
		vec, err := a.embedOne(ctx, url, text)
		if err != nil {
			return nil, 0, fmt.Errorf("allama embed input %d: %w", i, err)
		}
		// Learn the dimension from the first response when none was configured.
		if a.dim == 0 {
			a.dim = len(vec)
		}
		vecs[i] = vec
	}
	return vecs, 0, nil
}

// embedOne requests the embedding of a single text from url and converts the
// float64 values of the response to float32.
func (a *Allama) embedOne(ctx context.Context, url, text string) ([]float32, error) {
	// Upper bound on how much of an error response is copied into the error.
	const maxErrBody = 4096

	body, err := json.Marshal(map[string]any{"model": a.model, "prompt": text})
	if err != nil {
		return nil, fmt.Errorf("encode request: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := a.client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 300 {
		// Best effort: the body only adds detail to the error, so a read
		// failure here is deliberately ignored.
		msg, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrBody))
		if detail := strings.TrimSpace(string(msg)); detail != "" {
			return nil, fmt.Errorf("embed failed: %s: %s", resp.Status, detail)
		}
		return nil, fmt.Errorf("embed failed: %s", resp.Status)
	}

	data, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read response: %w", err)
	}
	var out struct {
		Embedding []float64 `json:"embedding"`
	}
	if err := json.Unmarshal(data, &out); err != nil {
		return nil, fmt.Errorf("decode response: %w", err)
	}

	vec := make([]float32, len(out.Embedding))
	for j, v := range out.Embedding {
		vec[j] = float32(v)
	}
	return vec, nil
}

View File

@ -64,10 +64,15 @@ func IngestRepo(ctx context.Context, cfg *cfgpkg.Config, ds cfgpkg.DataSource, o
defer conn.Close(ctx)
var embedder embed.Embedder
if embCfg.Model != "" {
embedder = embed.NewOpenAI(embCfg.BaseURL, embCfg.APIKey, embCfg.Model, embCfg.Dimension)
} else {
embedder = embed.NewBGE(embCfg.BaseURL, embCfg.APIKey, embCfg.Dimension)
switch embCfg.Provider {
case "allama":
embedder = embed.NewAllama(embCfg.BaseURL, embCfg.Model, embCfg.Dimension)
default:
if embCfg.Model != "" {
embedder = embed.NewOpenAI(embCfg.BaseURL, embCfg.APIKey, embCfg.Model, embCfg.Dimension)
} else {
embedder = embed.NewBGE(embCfg.BaseURL, embCfg.APIKey, embCfg.Dimension)
}
}
if err := store.EnsureSchema(ctx, conn, embedder.Dimension(), opt.MigrateDim); err != nil {
st.Errors = append(st.Errors, err)

View File

@ -60,10 +60,15 @@ func (s *Service) Query(ctx context.Context, question string, limit int) ([]Docu
return nil, nil
}
var emb embed.Embedder
if embCfg.Model != "" {
emb = embed.NewOpenAI(embCfg.BaseURL, embCfg.APIKey, embCfg.Model, embCfg.Dimension)
} else {
emb = embed.NewBGE(embCfg.BaseURL, embCfg.APIKey, embCfg.Dimension)
switch embCfg.Provider {
case "allama":
emb = embed.NewAllama(embCfg.BaseURL, embCfg.Model, embCfg.Dimension)
default:
if embCfg.Model != "" {
emb = embed.NewOpenAI(embCfg.BaseURL, embCfg.APIKey, embCfg.Model, embCfg.Dimension)
} else {
emb = embed.NewBGE(embCfg.BaseURL, embCfg.APIKey, embCfg.Dimension)
}
}
vecs, _, err := emb.Embed(ctx, []string{question})
if err != nil {

View File

@ -17,7 +17,7 @@ import (
)
// askFn performs the chat completion request. It is replaceable in tests.
var askFn = callChutes
var askFn = callLLM
// registerAskAIRoutes wires the /api/askai endpoint.
func registerAskAIRoutes(r *gin.RouterGroup) {
@ -61,9 +61,10 @@ type serverConfig struct {
} `yaml:"api"`
}
// loadConfig reads model, URL, timeout and retries from ConfigPath and
// environment variables. The Chutes API token is sourced from the config file.
func loadConfig() (string, string, string, time.Duration, int) {
// loadConfig reads provider, model, URL, timeout and retries from ConfigPath
// and environment variables.
func loadConfig() (string, string, string, string, time.Duration, int) {
provider := ""
model := os.Getenv("CHUTES_API_MODEL")
baseURL := os.Getenv("CHUTES_API_URL")
token := ""
@ -74,19 +75,28 @@ func loadConfig() (string, string, string, time.Duration, int) {
var cfg serverConfig
if err := yaml.Unmarshal(data, &cfg); err == nil {
for _, p := range cfg.Provider {
if p.Name != "chutes" {
continue
if provider == "" {
provider = p.Name
}
if token == "" {
token = p.Token
switch p.Name {
case "allama":
if model == "" && len(p.Models) > 0 {
model = p.Models[0]
}
if baseURL == "" {
baseURL = p.BaseURL
}
case "chutes":
if token == "" {
token = p.Token
}
if model == "" && len(p.Models) > 0 {
model = p.Models[0]
}
if baseURL == "" {
baseURL = p.BaseURL
}
}
if model == "" && len(p.Models) > 0 {
model = p.Models[0]
}
if baseURL == "" {
baseURL = p.BaseURL
}
break
}
if cfg.API.AskAI.Timeout > 0 {
timeout = time.Duration(cfg.API.AskAI.Timeout) * time.Second
@ -102,20 +112,28 @@ func loadConfig() (string, string, string, time.Duration, int) {
if retries > 3 {
retries = 3
}
if model == "" {
model = "deepseek-ai/DeepSeek-R1"
}
provider = strings.ToLower(provider)
baseURL = strings.TrimRight(baseURL, "/")
if provider == "allama" {
if baseURL == "" {
baseURL = "http://localhost:11434"
}
if model == "" {
model = "gpt-oss:20b"
}
return provider, token, model, baseURL + "/api/chat", timeout, retries
}
if baseURL == "" {
baseURL = "https://llm.chutes.ai"
}
url := baseURL + "/v1/chat/completions"
return token, model, url, timeout, retries
if model == "" {
model = "deepseek-ai/DeepSeek-R1"
}
return "chutes", token, model, baseURL + "/v1/chat/completions", timeout, retries
}
// callChutes sends the question to the hosted LLM service and returns the reply.
func callChutes(question string) (string, error) {
token, model, url, timeout, retries := loadConfig()
func callChutes(token, model, url string, timeout time.Duration, retries int, question string) (string, error) {
if token == "" || token == "cpk_xxxxxxx" {
return "", errors.New("chutes token not set")
}
@ -181,3 +199,65 @@ func callChutes(question string) (string, error) {
}
return "", lastErr
}
// callAllama sends the question to an Allama server's chat endpoint and
// returns the content of the reply message.
//
// The request body is a single non-streaming chat turn with question as the
// user message, POSTed to url with the given model. The call is attempted up
// to retries+1 times: transport errors, unreadable bodies, non-200 statuses
// and undecodable responses each count as a failed attempt, and the error of
// the last attempt is returned. When retries is negative no attempt is made
// and a generic "request failed" error is returned. timeout bounds each
// individual HTTP request.
func callAllama(model, url string, timeout time.Duration, retries int, question string) (string, error) {
	payload, err := json.Marshal(map[string]any{
		"model":    model,
		"messages": []any{map[string]string{"role": "user", "content": question}},
		"stream":   false,
	})
	if err != nil {
		return "", err
	}

	client := &http.Client{Timeout: timeout}
	var lastErr error
	for attempt := 0; attempt <= retries; attempt++ {
		req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(payload))
		if err != nil {
			// A malformed URL cannot succeed on a later attempt.
			return "", err
		}
		req.Header.Set("Content-Type", "application/json")

		resp, err := client.Do(req)
		if err != nil {
			lastErr = err
			continue
		}
		raw, err := io.ReadAll(resp.Body)
		resp.Body.Close()
		if err != nil {
			lastErr = err
			continue
		}
		if resp.StatusCode != http.StatusOK {
			// Always report the status so an empty error body still
			// produces a meaningful message.
			if detail := bytes.TrimSpace(raw); len(detail) > 0 {
				lastErr = fmt.Errorf("allama API error: %s: %s", resp.Status, detail)
			} else {
				lastErr = fmt.Errorf("allama API error: %s", resp.Status)
			}
			continue
		}

		var res struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		}
		if err := json.Unmarshal(raw, &res); err != nil {
			lastErr = err
			continue
		}
		return res.Message.Content, nil
	}
	if lastErr == nil {
		lastErr = errors.New("request failed")
	}
	return "", lastErr
}
// callLLM loads the ask-AI configuration and forwards the question to the
// backend it selects: callAllama when the provider is "allama", callChutes
// for every other provider value.
func callLLM(question string) (string, error) {
	provider, token, model, url, timeout, retries := loadConfig()
	if provider == "allama" {
		return callAllama(model, url, timeout, retries, question)
	}
	return callChutes(token, model, url, timeout, retries, question)
}

View File

@ -120,7 +120,7 @@ func TestRAGUpsert_DimensionMismatch(t *testing.T) {
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected status 500, got %d", w.Code)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("expected status 503, got %d", w.Code)
}
}

View File

@ -59,6 +59,13 @@ type Global struct {
VectorDB VectorDB `yaml:"vectordb"`
Datasources []Datasource `yaml:"datasources"`
Proxy string `yaml:"proxy"`
Embedding struct {
Provider string `yaml:"provider"`
BaseURL string `yaml:"base_url"`
Token string `yaml:"token"`
Model string `yaml:"model"`
Dimension int `yaml:"dimension"`
} `yaml:"embedding"`
}
type Provider struct {

View File

@ -1,15 +1,25 @@
log:
level: info
global:
#proxy: socks5://127.0.0.1:1080 # optional
redis:
addr: 127.0.0.1:6379
password: ""
vectordb:
pgurl: postgres://user:password@127.0.0.1:5432/postgres
datasources: []
embedding:
base_url: http://127.0.0.1:11434
provider: allama
base_url: http://localhost:11434
token: ""
model: bge-m3
dimension: 1536
provider:
- name: allama
base_url: http://localhost:11434
token: ""
models:
- gpt-oss:20b
api:
askai:
timeout: 100