diff --git a/client/main.go b/client/main.go index 2a59a45..bf947f7 100644 --- a/client/main.go +++ b/client/main.go @@ -69,10 +69,15 @@ var rootCmd = &cobra.Command{ chunkCfg := cfg.ResolveChunking() var embedder embed.Embedder - if embCfg.Model != "" { - embedder = embed.NewOpenAI(embCfg.BaseURL, embCfg.APIKey, embCfg.Model, embCfg.Dimension) - } else { - embedder = embed.NewBGE(embCfg.BaseURL, embCfg.APIKey, embCfg.Dimension) + switch embCfg.Provider { + case "allama": + embedder = embed.NewAllama(embCfg.BaseURL, embCfg.Model, embCfg.Dimension) + default: + if embCfg.Model != "" { + embedder = embed.NewOpenAI(embCfg.BaseURL, embCfg.APIKey, embCfg.Model, embCfg.Dimension) + } else { + embedder = embed.NewBGE(embCfg.BaseURL, embCfg.APIKey, embCfg.Dimension) + } } baseURL := os.Getenv("SERVER_URL") diff --git a/internal/rag/config/runtime.go b/internal/rag/config/runtime.go index 120fc21..2976a54 100644 --- a/internal/rag/config/runtime.go +++ b/internal/rag/config/runtime.go @@ -10,6 +10,7 @@ import ( // RuntimeEmbedding is the resolved embedding configuration used at runtime. type RuntimeEmbedding struct { + Provider string BaseURL string APIKey string Model string @@ -23,6 +24,7 @@ type RuntimeEmbedding struct { func (c *Config) ResolveEmbedding() RuntimeEmbedding { e := c.Embedding var rt RuntimeEmbedding + rt.Provider = e.Provider rt.Model = e.Model rt.Dimension = e.Dimension rt.RateLimitTPM = e.RateLimitTPM @@ -83,8 +85,10 @@ type Runtime struct { Datasources []DataSource `yaml:"datasources"` Proxy string `yaml:"proxy"` Embedding struct { + Provider string `yaml:"provider"` BaseURL string `yaml:"base_url"` Token string `yaml:"token"` + Model string `yaml:"model"` Dimension int `yaml:"dimension"` } `yaml:"embedding"` } @@ -117,8 +121,10 @@ func (rt *Runtime) ToConfig() *Config { c.Global.VectorDB = rt.VectorDB c.Global.Datasources = rt.Datasources c.Global.Proxy = rt.Proxy + c.Embedding.Provider = rt.Embedding.Provider c.Embedding.BaseURL = rt.Embedding.BaseURL c.Embedding.Token = rt.Embedding.Token + c.Embedding.Model = rt.Embedding.Model c.Embedding.Dimension = rt.Embedding.Dimension return &c } diff --git a/internal/rag/embed/allama.go b/internal/rag/embed/allama.go new file mode 100644 index 0000000..592e415 --- /dev/null +++ b/internal/rag/embed/allama.go @@ -0,0 +1,76 @@ +package embed + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +// Allama implements the Embedder interface using the Allama/Ollama embeddings API. +type Allama struct { + baseURL string + model string + dim int + client *http.Client +} + +// NewAllama creates a new Allama embedder. +func NewAllama(baseURL, model string, dim int) *Allama { + return &Allama{ + baseURL: strings.TrimRight(baseURL, "/"), + model: model, + dim: dim, + client: &http.Client{Timeout: 30 * time.Second}, + } +} + +// Dimension returns the embedding dimension if known. +func (a *Allama) Dimension() int { return a.dim } + +// Embed posts texts to the Allama embeddings endpoint. +func (a *Allama) Embed(ctx context.Context, inputs []string) ([][]float32, int, error) { + vecs := make([][]float32, len(inputs)) + url := a.baseURL + "/api/embeddings" + for i, text := range inputs { + payload := map[string]any{"model": a.model, "prompt": text} + body, _ := json.Marshal(payload) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, 0, err + } + req.Header.Set("Content-Type", "application/json") + resp, err := a.client.Do(req) + if err != nil { + return nil, 0, err + } + if resp.StatusCode >= 300 { + resp.Body.Close() + return nil, 0, fmt.Errorf("embed failed: %s", resp.Status) + } + data, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return nil, 0, err + } + var out struct { + Embedding []float64 `json:"embedding"` + } + if err := json.Unmarshal(data, &out); err != nil { + return nil, 0, err + } + if a.dim == 0 { + a.dim = len(out.Embedding) + } + vec := make([]float32, len(out.Embedding)) + for j, v := range out.Embedding { + vec[j] = float32(v) + } + vecs[i] = vec + } + return vecs, 0, nil +} diff --git a/internal/rag/ingest/ingest.go b/internal/rag/ingest/ingest.go index 847f8ec..a015c58 100644 --- a/internal/rag/ingest/ingest.go +++ b/internal/rag/ingest/ingest.go @@ -64,10 +64,15 @@ func IngestRepo(ctx context.Context, cfg *cfgpkg.Config, ds cfgpkg.DataSource, o defer conn.Close(ctx) var embedder embed.Embedder - if embCfg.Model != "" { - embedder = embed.NewOpenAI(embCfg.BaseURL, embCfg.APIKey, embCfg.Model, embCfg.Dimension) - } else { - embedder = embed.NewBGE(embCfg.BaseURL, embCfg.APIKey, embCfg.Dimension) + switch embCfg.Provider { + case "allama": + embedder = embed.NewAllama(embCfg.BaseURL, embCfg.Model, embCfg.Dimension) + default: + if embCfg.Model != "" { + embedder = embed.NewOpenAI(embCfg.BaseURL, embCfg.APIKey, embCfg.Model, embCfg.Dimension) + } else { + embedder = embed.NewBGE(embCfg.BaseURL, embCfg.APIKey, embCfg.Dimension) + } } if err := store.EnsureSchema(ctx, conn, embedder.Dimension(), opt.MigrateDim); err != nil { st.Errors = append(st.Errors, err) diff --git a/internal/rag/service.go b/internal/rag/service.go index e761766..2cce100 100644 --- a/internal/rag/service.go +++ b/internal/rag/service.go @@ -60,10 +60,15 @@ func (s *Service) Query(ctx context.Context, question string, limit int) ([]Docu return nil, nil } var emb embed.Embedder - if embCfg.Model != "" { - emb = embed.NewOpenAI(embCfg.BaseURL, embCfg.APIKey, embCfg.Model, embCfg.Dimension) - } else { - emb = embed.NewBGE(embCfg.BaseURL, embCfg.APIKey, embCfg.Dimension) + switch embCfg.Provider { + case "allama": + emb = embed.NewAllama(embCfg.BaseURL, embCfg.Model, embCfg.Dimension) + default: + if embCfg.Model != "" { + emb = embed.NewOpenAI(embCfg.BaseURL, embCfg.APIKey, embCfg.Model, embCfg.Dimension) + } else { + emb = embed.NewBGE(embCfg.BaseURL, embCfg.APIKey, embCfg.Dimension) + } } vecs, _, err := emb.Embed(ctx, []string{question}) if err != nil { diff --git a/server/api/askai.go b/server/api/askai.go index 2f3db11..5b4fcc6 100644 --- a/server/api/askai.go +++ b/server/api/askai.go @@ -17,7 +17,7 @@ import ( ) // askFn performs the chat completion request. It is replaceable in tests. -var askFn = callChutes +var askFn = callLLM // registerAskAIRoutes wires the /api/askai endpoint. func registerAskAIRoutes(r *gin.RouterGroup) { @@ -61,9 +61,10 @@ type serverConfig struct { } `yaml:"api"` } -// loadConfig reads model, URL, timeout and retries from ConfigPath and -// environment variables. The Chutes API token is sourced from the config file. -func loadConfig() (string, string, string, time.Duration, int) { +// loadConfig reads provider, model, URL, timeout and retries from ConfigPath +// and environment variables. +func loadConfig() (string, string, string, string, time.Duration, int) { + provider := "" model := os.Getenv("CHUTES_API_MODEL") baseURL := os.Getenv("CHUTES_API_URL") token := "" @@ -74,19 +75,28 @@ func loadConfig() (string, string, string, time.Duration, int) { var cfg serverConfig if err := yaml.Unmarshal(data, &cfg); err == nil { for _, p := range cfg.Provider { - if p.Name != "chutes" { - continue + if provider == "" { + provider = p.Name } - if token == "" { - token = p.Token + switch p.Name { + case "allama": + if model == "" && len(p.Models) > 0 { + model = p.Models[0] + } + if baseURL == "" { + baseURL = p.BaseURL + } + case "chutes": + if token == "" { + token = p.Token + } + if model == "" && len(p.Models) > 0 { + model = p.Models[0] + } + if baseURL == "" { + baseURL = p.BaseURL + } } - if model == "" && len(p.Models) > 0 { - model = p.Models[0] - } - if baseURL == "" { - baseURL = p.BaseURL - } - break } if cfg.API.AskAI.Timeout > 0 { timeout = time.Duration(cfg.API.AskAI.Timeout) * time.Second @@ -102,20 +112,28 @@ func loadConfig() (string, string, string, time.Duration, int) { if retries > 3 { retries = 3 } - if model == "" { - model = "deepseek-ai/DeepSeek-R1" - } + provider = strings.ToLower(provider) baseURL = strings.TrimRight(baseURL, "/") + if provider == "allama" { + if baseURL == "" { + baseURL = "http://localhost:11434" + } + if model == "" { + model = "gpt-oss:20b" + } + return provider, token, model, baseURL + "/api/chat", timeout, retries + } if baseURL == "" { baseURL = "https://llm.chutes.ai" } - url := baseURL + "/v1/chat/completions" - return token, model, url, timeout, retries + if model == "" { + model = "deepseek-ai/DeepSeek-R1" + } + return "chutes", token, model, baseURL + "/v1/chat/completions", timeout, retries } // callChutes sends the question to the hosted LLM service and returns the reply. -func callChutes(question string) (string, error) { - token, model, url, timeout, retries := loadConfig() +func callChutes(token, model, url string, timeout time.Duration, retries int, question string) (string, error) { if token == "" || token == "cpk_xxxxxxx" { return "", errors.New("chutes token not set") } @@ -181,3 +199,65 @@ func callChutes(question string) (string, error) { } return "", lastErr } + +// callAllama sends the question to a local Allama server. +func callAllama(model, url string, timeout time.Duration, retries int, question string) (string, error) { + reqBody := map[string]any{ + "model": model, + "messages": []any{map[string]string{"role": "user", "content": question}}, + "stream": false, + } + data, err := json.Marshal(reqBody) + if err != nil { + return "", err + } + client := &http.Client{Timeout: timeout} + var lastErr error + for i := 0; i <= retries; i++ { + req, err := http.NewRequest("POST", url, bytes.NewReader(data)) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/json") + resp, err := client.Do(req) + if err != nil { + lastErr = err + continue + } + b, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + lastErr = err + continue + } + if resp.StatusCode != http.StatusOK { + lastErr = fmt.Errorf("allama API error: %s", string(b)) + continue + } + var res struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + } + if err := json.Unmarshal(b, &res); err != nil { + lastErr = err + continue + } + return res.Message.Content, nil + } + if lastErr == nil { + lastErr = errors.New("request failed") + } + return "", lastErr +} + +// callLLM dispatches the question to the configured provider. +func callLLM(question string) (string, error) { + provider, token, model, url, timeout, retries := loadConfig() + switch provider { + case "allama": + return callAllama(model, url, timeout, retries, question) + default: + return callChutes(token, model, url, timeout, retries, question) + } +} diff --git a/server/api/rag_test.go b/server/api/rag_test.go index 68bf6e9..438a531 100644 --- a/server/api/rag_test.go +++ b/server/api/rag_test.go @@ -120,7 +120,7 @@ func TestRAGUpsert_DimensionMismatch(t *testing.T) { req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() r.ServeHTTP(w, req) - if w.Code != http.StatusInternalServerError { - t.Fatalf("expected status 500, got %d", w.Code) + if w.Code != http.StatusServiceUnavailable { + t.Fatalf("expected status 503, got %d", w.Code) } } diff --git a/server/config/config.go b/server/config/config.go index 8d99658..40afa47 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -59,6 +59,13 @@ type Global struct { VectorDB VectorDB `yaml:"vectordb"` Datasources []Datasource `yaml:"datasources"` Proxy string `yaml:"proxy"` + Embedding struct { + Provider string `yaml:"provider"` + BaseURL string `yaml:"base_url"` + Token string `yaml:"token"` + Model string `yaml:"model"` + Dimension int `yaml:"dimension"` + } `yaml:"embedding"` } type Provider struct { diff --git a/server/config/server.yaml b/server/config/server.yaml index be23c9d..f5c61c8 100644 --- a/server/config/server.yaml +++ b/server/config/server.yaml @@ -1,15 +1,25 @@ log: level: info global: + #proxy: socks5://127.0.0.1:1080 # optional redis: addr: 127.0.0.1:6379 + password: "" vectordb: pgurl: postgres://user:password@127.0.0.1:5432/postgres datasources: [] embedding: - base_url: http://127.0.0.1:11434 + provider: allama + base_url: http://localhost:11434 token: "" + model: bge-m3 dimension: 1536 +provider: + - name: allama + base_url: http://localhost:11434 + token: "" + models: + - gpt-oss:20b api: askai: timeout: 100