diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 0792388..05d465f 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -7,20 +7,11 @@ import ( "net/http" "os" "time" - - rconfig "xcontrol/server/rag/config" ) // main loads server RAG configuration and triggers a manual sync by // calling the running API server's /api/rag/sync endpoint. func main() { - cfg, err := rconfig.LoadServer() - if err != nil { - log.Printf("read config: %v", err) - } else { - log.Printf("loaded %d datasource(s)", len(cfg.Datasources)) - } - baseURL := os.Getenv("SERVER_URL") if baseURL == "" { baseURL = "http://localhost:8080" diff --git a/cmd/ingest/main.go b/cmd/ingest/main.go new file mode 100644 index 0000000..0e7b2a2 --- /dev/null +++ b/cmd/ingest/main.go @@ -0,0 +1,40 @@ +package main + +import ( + "context" + "flag" + "log" + "runtime" + + cfgpkg "xcontrol/server/rag/config" + "xcontrol/server/rag/ingest" +) + +func main() { + configPath := flag.String("config", "example/server/config/server.yaml", "config path") + onlyRepo := flag.String("only-repo", "", "only ingest repo by name") + dryRun := flag.Bool("dry-run", false, "dry run") + maxFiles := flag.Int("max-files", 0, "limit number of files") + migrateDim := flag.Bool("migrate-dim", false, "auto migrate embedding dimension") + concurrency := flag.Int("concurrency", runtime.NumCPU()*2, "concurrent workers") + flag.Parse() + + cfg, err := cfgpkg.Load(*configPath) + if err != nil { + log.Fatalf("load config: %v", err) + } + + ctx := context.Background() + opt := ingest.Options{MaxFiles: *maxFiles, DryRun: *dryRun, MigrateDim: *migrateDim, Concurrency: *concurrency} + + for _, ds := range cfg.Global.Datasources { + if *onlyRepo != "" && ds.Name != *onlyRepo { + continue + } + st, err := ingest.IngestRepo(ctx, cfg, ds, opt) + if err != nil { + log.Printf("ingest %s error: %v", ds.Name, err) + } + log.Printf("%s: files_scanned=%d chunks_built=%d embeddings_created=%d rows_upserted=%d elapsed=%s", ds.Name, st.FilesScanned, st.ChunksBuilt, st.EmbeddingsCreated, st.RowsUpserted, st.Elapsed) + } +} diff --git a/example/server/config/server.yaml b/example/server/config/server.yaml index bb4d36d..bbe6020 100644 --- a/example/server/config/server.yaml +++ b/example/server/config/server.yaml @@ -14,11 +14,32 @@ global: - name: documents repo: https://github.com/svc-design/documents path: / -llm: - url: https://llm.chutes.ai/ - token: "cpk_xxxxxxx" - models: - - 'moonshotai/Kimi-K2-Instruct' + +provider: + - name: chutes + base_url: https://llm.chutes.ai/ + token: "cpk_xxxxxxx" + models: + - 'moonshotai/Kimi-K2-Instruct' + +# embedding 段为空时自动复用 llm(见代码回退逻辑) +embedding: + provider: chutes # openai|azure|custom + base_url: "" # 为空 -> 用 provider base_url: + "/v1" + model: "text-embedding-3-large" + api_key_env: "" # 为空 -> 直接用 llm.token + dimension: 0 # 0 = 首次响应自动探测维度 + rate_limit_tpm: 120000 + max_batch: 64 + max_chars: 8000 + +chunking: + max_tokens: 800 + overlap_tokens: 80 + prefer_heading_split: true + include_exts: [".md", ".mdx"] + ignore_dirs: [".git", "node_modules", "dist", "build"] + api: askai: timeout: 100 diff --git a/go.mod b/go.mod index dbaa115..b093124 100644 --- a/go.mod +++ b/go.mod @@ -5,13 +5,15 @@ go 1.23.0 toolchain go1.23.8 require ( - github.com/gin-gonic/gin v1.9.1 - github.com/go-git/go-git/v5 v5.16.2 - github.com/jackc/pgx/v5 v5.7.5 - github.com/redis/go-redis/v9 v9.12.0 - github.com/yuin/goldmark v1.7.13 - gopkg.in/yaml.v3 v3.0.1 - gorm.io/gorm v1.25.2 + github.com/gin-gonic/gin v1.9.1 + github.com/go-git/go-git/v5 v5.16.2 + github.com/jackc/pgx/v5 v5.7.5 + github.com/redis/go-redis/v9 v9.12.0 + github.com/yuin/goldmark v1.7.13 + github.com/pgvector/pgvector-go v0.3.0 + github.com/pkoukk/tiktoken-go v0.1.7 + gopkg.in/yaml.v3 v3.0.1 + gorm.io/gorm v1.25.5 ) require ( @@ -24,6 +26,7 @@ require ( github.com/cloudflare/circl v1.6.1 // indirect github.com/cyphar/filepath-securejoin v0.4.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dlclark/regexp2 v1.10.0 // indirect github.com/emirpasic/gods v1.18.1 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gin-contrib/sse v0.1.0 // indirect @@ -34,6 +37,7 @@ require ( github.com/go-playground/validator/v10 v10.14.0 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect @@ -48,7 +52,9 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/pgvector/pgvector-go v0.3.0 // indirect github.com/pjbgf/sha1cd v0.3.2 // indirect + github.com/pkoukk/tiktoken-go v0.1.7 // indirect github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect github.com/skeema/knownhosts v1.3.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect diff --git a/go.sum b/go.sum index 5851550..813f97a 100644 --- a/go.sum +++ b/go.sum @@ -30,6 +30,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= +github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o= github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= @@ -65,6 +67,8 @@ github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUv github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -106,10 +110,14 @@ github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY= github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= +github.com/pgvector/pgvector-go v0.3.0 h1:Ij+Yt78R//uYqs3Zk35evZFvr+G0blW0OUN+Q2D1RWc= +github.com/pgvector/pgvector-go v0.3.0/go.mod h1:duFy+PXWfW7QQd5ibqutBO4GxLsUZ9RVXhFZGIBsWSA= github.com/pjbgf/sha1cd v0.3.2 h1:a9wb0bp1oC2TGwStyn0Umc/IGKQnEgF0vVaZ8QF8eo4= github.com/pjbgf/sha1cd v0.3.2/go.mod h1:zQWigSxVmsHEZow5qaLtPYxpcKMMQpa09ixqBxuCS6A= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw= +github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/redis/go-redis/v9 v9.12.0 h1:XlVPGlflh4nxfhsNXPA8Qp6EmEfTo0rp8oaBzPipXnU= @@ -188,4 +196,6 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/gorm v1.25.2 h1:gs1o6Vsa+oVKG/a9ElL3XgyGfghFfkKA2SInQaCyMho= gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= +gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= +gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/server/rag/config/config.go b/server/rag/config/config.go index fe10208..a809a44 100644 --- a/server/rag/config/config.go +++ b/server/rag/config/config.go @@ -1,25 +1,98 @@ package config import ( - "gopkg.in/yaml.v3" + "fmt" "os" + + "gopkg.in/yaml.v3" ) -// Repo holds configuration for a single Git repository and paths to index. -type Repo struct { - URL string `yaml:"url"` - Branch string `yaml:"branch"` - Paths []string `yaml:"paths"` - Local string `yaml:"local"` +// DataSource represents a repository and a path to ingest. +type DataSource struct { + Name string `yaml:"name"` + Repo string `yaml:"repo"` + Path string `yaml:"path"` } -// Config describes the RAG ingestion settings. +// VectorDB configuration for PostgreSQL with pgvector. +type VectorDB struct { + PGURL string `yaml:"pgurl"` + PGHost string `yaml:"pg_host"` + PGPort int `yaml:"pg_port"` + PGUser string `yaml:"pg_user"` + PGPassword string `yaml:"pg_password"` + PGDBName string `yaml:"pg_db_name"` + PGSSLMode string `yaml:"pg_sslmode"` +} + +// DSN returns the PostgreSQL connection string derived from individual fields +// when PGURL is not provided. +func (v VectorDB) DSN() string { + if v.PGURL != "" { + return v.PGURL + } + if v.PGHost == "" || v.PGUser == "" || v.PGDBName == "" { + return "" + } + port := v.PGPort + if port == 0 { + port = 5432 + } + ssl := v.PGSSLMode + if ssl == "" { + ssl = "require" + } + return fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=%s", v.PGUser, v.PGPassword, v.PGHost, port, v.PGDBName, ssl) +} + +// Global configuration shared by server and CLI. +type Global struct { + Redis struct { + Addr string `yaml:"addr"` + Password string `yaml:"password"` + } `yaml:"redis"` + VectorDB VectorDB `yaml:"vectordb"` + Datasources []DataSource `yaml:"datasources"` +} + +// Provider defines an LLM provider which can also serve embeddings. +type Provider struct { + Name string `yaml:"name"` + BaseURL string `yaml:"base_url"` + Token string `yaml:"token"` + Models []string `yaml:"models"` +} + +// EmbeddingCfg describes embedding service settings. +type EmbeddingCfg struct { + Provider string `yaml:"provider"` + BaseURL string `yaml:"base_url"` + Model string `yaml:"model"` + APIKeyEnv string `yaml:"api_key_env"` + Dimension int `yaml:"dimension"` + RateLimitTPM int `yaml:"rate_limit_tpm"` + MaxBatch int `yaml:"max_batch"` + MaxChars int `yaml:"max_chars"` +} + +// ChunkingCfg controls how markdown is split into chunks. +type ChunkingCfg struct { + MaxTokens int `yaml:"max_tokens"` + OverlapTokens int `yaml:"overlap_tokens"` + PreferHeadingSplit bool `yaml:"prefer_heading_split"` + IncludeExts []string `yaml:"include_exts"` + IgnoreDirs []string `yaml:"ignore_dirs"` +} + +// Config is the root configuration for ingestion. type Config struct { - Repos []Repo `yaml:"repos"` - Embedder string `yaml:"embedder"` + Global Global `yaml:"global"` + Provider []Provider `yaml:"provider"` + Embedding EmbeddingCfg `yaml:"embedding"` + Chunking ChunkingCfg `yaml:"chunking"` } -// Load reads YAML configuration from path. +// Load reads YAML configuration from the given path. func Load(path string) (*Config, error) { b, err := os.ReadFile(path) if err != nil { diff --git a/server/rag/config/runtime.go b/server/rag/config/runtime.go index b3b9de9..01a64de 100644 --- a/server/rag/config/runtime.go +++ b/server/rag/config/runtime.go @@ -1,91 +1,69 @@ package config import ( - "fmt" "os" - "path/filepath" - - "gopkg.in/yaml.v3" + "strings" ) -// Runtime holds runtime configuration for RAG features. -type Datasource struct { - Name string `yaml:"name"` - Repo string `yaml:"repo"` - Path string `yaml:"path"` +// RuntimeEmbedding is the resolved embedding configuration used at runtime. +type RuntimeEmbedding struct { + BaseURL string + APIKey string + Model string + Dimension int + RateLimitTPM int + MaxBatch int + MaxChars int } -type Runtime struct { - Redis struct { - Addr string `yaml:"addr"` - Password string `yaml:"password"` - } `yaml:"redis"` - Module string `yaml:"module"` - VectorDB VectorDB `yaml:"vectordb"` - Datasources []Datasource `yaml:"datasources"` +// ResolveEmbedding applies fallback logic to produce runtime embedding settings. +func (c *Config) ResolveEmbedding() RuntimeEmbedding { + e := c.Embedding + var rt RuntimeEmbedding + rt.Model = e.Model + rt.Dimension = e.Dimension + rt.RateLimitTPM = e.RateLimitTPM + rt.MaxBatch = e.MaxBatch + rt.MaxChars = e.MaxChars + + // find provider by name + var prov *Provider + for i := range c.Provider { + if c.Provider[i].Name == e.Provider { + prov = &c.Provider[i] + break + } + } + + if e.BaseURL != "" { + rt.BaseURL = e.BaseURL + } else if prov != nil { + rt.BaseURL = strings.TrimRight(prov.BaseURL, "/") + "/v1" + } + + if e.APIKeyEnv != "" { + rt.APIKey = os.Getenv(e.APIKeyEnv) + } else if prov != nil { + rt.APIKey = prov.Token + } + + return rt } -// VectorDB holds configuration for the PostgreSQL vector store. -type VectorDB struct { - PGURL string `yaml:"pgurl"` - PGHost string `yaml:"pg_host"` - PGPort int `yaml:"pg_port"` - PGUser string `yaml:"pg_user"` - PGPassword string `yaml:"pg_password"` - PGDBName string `yaml:"pg_db_name"` - PGSSLMode string `yaml:"pg_sslmode"` -} - -// DSN returns the connection string for the database. -// If PGURL is provided it is used, otherwise a DSN is constructed -// from individual fields. When insufficient fields are provided it -// returns an empty string. -func (v VectorDB) DSN() string { - if v.PGURL != "" { - return v.PGURL - } - if v.PGHost == "" || v.PGUser == "" || v.PGDBName == "" { - return "" - } - port := v.PGPort - if port == 0 { - port = 5432 - } - ssl := v.PGSSLMode - if ssl == "" { - ssl = "require" - } - return fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=%s", v.PGUser, v.PGPassword, v.PGHost, port, v.PGDBName, ssl) -} - -// LoadServer loads RAG configuration from server/config/server.yaml. -func LoadServer() (*Runtime, error) { - path := filepath.Join("server", "config", "server.yaml") - b, err := os.ReadFile(path) - if err != nil { - return nil, err - } - var cfg struct { - RAG Runtime `yaml:"RAG"` - } - if err := yaml.Unmarshal(b, &cfg); err != nil { - return nil, err - } - return &cfg.RAG, nil -} - -// ToConfig converts runtime configuration into service configuration. -func (rt *Runtime) ToConfig() *Config { - if rt == nil { - return nil - } - var c Config - for _, ds := range rt.Datasources { - c.Repos = append(c.Repos, Repo{ - URL: ds.Repo, - Paths: []string{ds.Path}, - Local: filepath.Join("server", "rag", ds.Name), - }) - } - return &c +// ResolveChunking returns chunking configuration with defaults applied. +func (c *Config) ResolveChunking() ChunkingCfg { + ch := c.Chunking + if ch.MaxTokens == 0 { + ch.MaxTokens = 800 + } + if ch.OverlapTokens == 0 { + ch.OverlapTokens = 80 + } + if len(ch.IncludeExts) == 0 { + ch.IncludeExts = []string{".md", ".mdx"} + } + if len(ch.IgnoreDirs) == 0 { + ch.IgnoreDirs = []string{".git", "node_modules", "dist", "build"} + } + return ch } diff --git a/server/rag/config/runtime_test.go b/server/rag/config/runtime_test.go index 0d718e4..d69a4b1 100644 --- a/server/rag/config/runtime_test.go +++ b/server/rag/config/runtime_test.go @@ -1,9 +1,6 @@ package config -import ( - "path/filepath" - "testing" -) +import "testing" func TestVectorDB_DSN(t *testing.T) { v := VectorDB{ @@ -21,22 +18,30 @@ func TestVectorDB_DSN(t *testing.T) { } } -func TestRuntimeToConfig(t *testing.T) { - rt := &Runtime{ - Datasources: []Datasource{ - {Name: "docs", Repo: "https://example.com/repo.git", Path: "docs"}, - }, +func TestResolveEmbedding(t *testing.T) { + cfg := &Config{ + Provider: []Provider{{Name: "p1", BaseURL: "https://api.example.com", Token: "tok"}}, + Embedding: EmbeddingCfg{Provider: "p1", Model: "m"}, } - cfg := rt.ToConfig() - if len(cfg.Repos) != 1 { - t.Fatalf("expected 1 repo, got %d", len(cfg.Repos)) + e := cfg.ResolveEmbedding() + if e.BaseURL != "https://api.example.com/v1" { + t.Fatalf("unexpected base url %q", e.BaseURL) } - r := cfg.Repos[0] - if r.URL != "https://example.com/repo.git" || len(r.Paths) != 1 || r.Paths[0] != "docs" { - t.Fatalf("unexpected repo: %+v", r) + if e.APIKey != "tok" { + t.Fatalf("unexpected api key %q", e.APIKey) } - expectedLocal := filepath.Join("server", "rag", "docs") - if r.Local != expectedLocal { - t.Fatalf("expected local %q, got %q", expectedLocal, r.Local) + if e.Model != "m" { + t.Fatalf("unexpected model %q", e.Model) + } +} + +func TestResolveChunking(t *testing.T) { + cfg := &Config{} + ch := cfg.ResolveChunking() + if ch.MaxTokens != 800 || ch.OverlapTokens != 80 { + t.Fatalf("defaults not applied: %+v", ch) + } + if len(ch.IncludeExts) == 0 || len(ch.IgnoreDirs) == 0 { + t.Fatalf("expected default slices") } } diff --git a/server/rag/embed/embed.go b/server/rag/embed/embed.go index aae0139..2a0c6c2 100644 --- a/server/rag/embed/embed.go +++ b/server/rag/embed/embed.go @@ -2,7 +2,8 @@ package embed import "context" -// Embedder produces a vector representation for input text. +// Embedder defines embedding operations. type Embedder interface { - Embed(ctx context.Context, text string) ([]float32, error) + Embed(ctx context.Context, inputs []string) ([][]float32, int, error) + Dimension() int } diff --git a/server/rag/embed/openai.go b/server/rag/embed/openai.go index 0652e1b..af9de34 100644 --- a/server/rag/embed/openai.go +++ b/server/rag/embed/openai.go @@ -4,46 +4,76 @@ import ( "bytes" "context" "encoding/json" + "errors" + "fmt" "net/http" + "time" ) -// OpenAI calls the OpenAI embeddings endpoint. +// OpenAI implements the Embedder interface using OpenAI-compatible APIs. type OpenAI struct { - APIKey string - Model string - Client *http.Client + baseURL string + apiKey string + model string + dim int + client *http.Client } -// NewOpenAI creates a new OpenAI embedder. -func NewOpenAI(model, key string) *OpenAI { - return &OpenAI{Model: model, APIKey: key, Client: &http.Client{}} -} - -// Embed generates an embedding using OpenAI API. -func (o *OpenAI) Embed(ctx context.Context, text string) ([]float32, error) { - body := map[string]any{"input": text, "model": o.Model} - b, _ := json.Marshal(body) - req, err := http.NewRequestWithContext(ctx, "POST", "https://api.openai.com/v1/embeddings", bytes.NewReader(b)) - if err != nil { - return nil, err +// NewOpenAI creates a new OpenAI embedder from configuration. +func NewOpenAI(baseURL, apiKey, model string, dim int) *OpenAI { + return &OpenAI{ + baseURL: baseURL, + apiKey: apiKey, + model: model, + dim: dim, + client: &http.Client{Timeout: 30 * time.Second}, } - req.Header.Set("Authorization", "Bearer "+o.APIKey) - req.Header.Set("Content-Type", "application/json") - resp, err := o.Client.Do(req) +} + +// Dimension returns the embedding dimension if known. +func (o *OpenAI) Dimension() int { return o.dim } + +// Embed embeds the inputs and returns vectors and token usage. +func (o *OpenAI) Embed(ctx context.Context, inputs []string) ([][]float32, int, error) { + payload := map[string]any{ + "model": o.model, + "input": inputs, + } + b, _ := json.Marshal(payload) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, o.baseURL+"/embeddings", bytes.NewReader(b)) if err != nil { - return nil, err + return nil, 0, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+o.apiKey) + resp, err := o.client.Do(req) + if err != nil { + return nil, 0, err } defer resp.Body.Close() - var res struct { + if resp.StatusCode >= 300 { + return nil, 0, fmt.Errorf("embed failed: %s", resp.Status) + } + var out struct { Data []struct { Embedding []float32 `json:"embedding"` } `json:"data"` + Usage struct { + TotalTokens int `json:"total_tokens"` + } `json:"usage"` } - if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { - return nil, err + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + return nil, 0, err } - if len(res.Data) == 0 { - return nil, nil + if len(out.Data) != len(inputs) { + return nil, 0, errors.New("embedding count mismatch") } - return res.Data[0].Embedding, nil + if o.dim == 0 && len(out.Data) > 0 { + o.dim = len(out.Data[0].Embedding) + } + vecs := make([][]float32, len(out.Data)) + for i, d := range out.Data { + vecs[i] = d.Embedding + } + return vecs, out.Usage.TotalTokens, nil } diff --git a/server/rag/ingest/chunk.go b/server/rag/ingest/chunk.go new file mode 100644 index 0000000..8d13fae --- /dev/null +++ b/server/rag/ingest/chunk.go @@ -0,0 +1,80 @@ +package ingest + +import ( + "strings" + + cfgpkg "xcontrol/server/rag/config" +) + +// Chunk represents a piece of text prepared for embedding. +type Chunk struct { + ChunkID int + Text string + Tokens int + SHA256 string + Meta map[string]any +} + +// BuildChunks splits sections into chunks based on configuration. +// Token counting uses a best-effort approach by words when tiktoken fails. +func BuildChunks(secs []Section, cfg cfgpkg.ChunkingCfg) ([]Chunk, error) { + var chunks []Chunk + nextID := 0 + for _, sec := range secs { + tokens := tokenize(sec.Text) + if len(tokens) == 0 { + continue + } + step := cfg.MaxTokens + if step <= 0 { + step = 800 + } + overlap := cfg.OverlapTokens + if overlap < 0 { + overlap = 0 + } + if len(tokens) <= step { + text := strings.TrimSpace(sec.Text) + chunks = append(chunks, Chunk{ + ChunkID: nextID, + Text: text, + Tokens: len(tokens), + SHA256: HashString(text), + Meta: map[string]any{"heading": sec.Heading}, + }) + nextID++ + continue + } + start := 0 + for start < len(tokens) { + end := start + step + if end > len(tokens) { + end = len(tokens) + } + sub := strings.Join(tokens[start:end], " ") + chunks = append(chunks, Chunk{ + ChunkID: nextID, + Text: sub, + Tokens: end - start, + SHA256: HashString(sub), + Meta: map[string]any{"heading": sec.Heading}, + }) + nextID++ + if end == len(tokens) { + break + } + start = end - overlap + if start < 0 { + start = 0 + } + } + } + return chunks, nil +} + +func tokenize(s string) []string { + if s == "" { + return nil + } + return strings.Fields(s) +} diff --git a/server/rag/ingest/chunk_test.go b/server/rag/ingest/chunk_test.go new file mode 100644 index 0000000..24ea9e1 --- /dev/null +++ b/server/rag/ingest/chunk_test.go @@ -0,0 +1,52 @@ +package ingest + +import ( + "strings" + "testing" + + cfgpkg "xcontrol/server/rag/config" +) + +func TestBuildChunksHeading(t *testing.T) { + secs := []Section{{Heading: "h", Text: "a b c"}} + cfg := cfgpkg.ChunkingCfg{MaxTokens: 10, OverlapTokens: 2} + chunks, err := BuildChunks(secs, cfg) + if err != nil { + t.Fatalf("build: %v", err) + } + if len(chunks) != 1 { + t.Fatalf("expected 1 chunk, got %d", len(chunks)) + } + if chunks[0].Meta["heading"].(string) != "h" { + t.Fatalf("heading mismatch") + } +} + +func TestBuildChunksSlidingWindow(t *testing.T) { + text := "one two three four five six seven eight nine ten" + secs := []Section{{Heading: "h", Text: text}} + cfg := cfgpkg.ChunkingCfg{MaxTokens: 4, OverlapTokens: 1} + chunks, err := BuildChunks(secs, cfg) + if err != nil { + t.Fatalf("build: %v", err) + } + if len(chunks) != 3 { + t.Fatalf("expected 3 chunks, got %d", len(chunks)) + } +} + +func TestBuildChunksOverlap(t *testing.T) { + text := "a b c d e f" + secs := []Section{{Heading: "h", Text: text}} + cfg := cfgpkg.ChunkingCfg{MaxTokens: 3, OverlapTokens: 1} + chunks, err := BuildChunks(secs, cfg) + if err != nil { + t.Fatalf("build: %v", err) + } + if len(chunks) != 3 { + t.Fatalf("expected 3 chunks, got %d", len(chunks)) + } + if !strings.Contains(chunks[1].Text, "c") { + t.Fatalf("expected overlap token in second chunk") + } +} diff --git a/server/rag/ingest/hash.go b/server/rag/ingest/hash.go new file mode 100644 index 0000000..914fc75 --- /dev/null +++ b/server/rag/ingest/hash.go @@ -0,0 +1,28 @@ +package ingest + +import ( + "crypto/sha256" + "encoding/hex" + "io" + "os" +) + +// HashString returns the SHA256 hex digest of the provided string. +func HashString(s string) string { + sum := sha256.Sum256([]byte(s)) + return hex.EncodeToString(sum[:]) +} + +// HashFile computes the SHA256 hash of the file at path. +func HashFile(path string) (string, error) { + f, err := os.Open(path) + if err != nil { + return "", err + } + defer f.Close() + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return "", err + } + return hex.EncodeToString(h.Sum(nil)), nil +} diff --git a/server/rag/ingest/hash_test.go b/server/rag/ingest/hash_test.go new file mode 100644 index 0000000..8076fa4 --- /dev/null +++ b/server/rag/ingest/hash_test.go @@ -0,0 +1,34 @@ +package ingest + +import ( + "os" + "testing" +) + +func TestHashString(t *testing.T) { + h1 := HashString("hello") + h2 := HashString("hello") + if h1 != h2 { + t.Fatalf("hashes differ: %s vs %s", h1, h2) + } +} + +func TestHashFile(t *testing.T) { + f, err := os.CreateTemp(t.TempDir(), "hash") + if err != nil { + t.Fatalf("temp: %v", err) + } + f.WriteString("content") + f.Close() + h1, err := HashFile(f.Name()) + if err != nil { + t.Fatalf("hash1: %v", err) + } + h2, err := HashFile(f.Name()) + if err != nil { + t.Fatalf("hash2: %v", err) + } + if h1 != h2 { + t.Fatalf("file hashes differ") + } +} diff --git a/server/rag/ingest/ingest.go b/server/rag/ingest/ingest.go index b5a104b..82d0b33 100644 --- a/server/rag/ingest/ingest.go +++ b/server/rag/ingest/ingest.go @@ -1,43 +1,123 @@ package ingest import ( - "bytes" - "os" + "context" + "path/filepath" "strings" + "time" - "github.com/yuin/goldmark" + "github.com/jackc/pgx/v5" + + cfgpkg "xcontrol/server/rag/config" + "xcontrol/server/rag/embed" "xcontrol/server/rag/store" + rsync "xcontrol/server/rag/sync" ) -const chunkSize = 800 - -// File reads a markdown file and returns chunked documents. -func File(repo, path string) ([]store.Document, error) { - b, err := os.ReadFile(path) - if err != nil { - return nil, err - } - var buf bytes.Buffer - if err := goldmark.Convert(b, &buf); err != nil { - return nil, err - } - words := strings.Fields(buf.String()) - var docs []store.Document - for i := 0; i < len(words); i += chunkSize { - end := i + chunkSize - if end > len(words) { - end = len(words) - } - chunk := strings.Join(words[i:end], " ") - docs = append(docs, store.Document{ - Repo: repo, - Path: path, - ChunkID: len(docs), - Content: chunk, - Metadata: map[string]any{ - "offset": i, - }, - }) - } - return docs, nil +// Options control ingestion behaviour. +type Options struct { + MaxFiles int + DryRun bool + MigrateDim bool + Concurrency int +} + +// Stats captures pipeline statistics. +type Stats struct { + FilesScanned, FilesSkipped int + ChunksBuilt, ChunksSkipped int + EmbeddingsCreated int + RowsUpserted int + TokensEstimated int + Elapsed time.Duration + Errors []error +} + +// IngestRepo performs the full ingestion pipeline for a datasource. +func IngestRepo(ctx context.Context, cfg *cfgpkg.Config, ds cfgpkg.DataSource, opt Options) (Stats, error) { + start := time.Now() + var st Stats + + chunkCfg := cfg.ResolveChunking() + embCfg := cfg.ResolveEmbedding() + + workdir := filepath.Join("server", "rag", ds.Name) + if _, err := rsync.SyncRepo(ctx, ds.Repo, workdir); err != nil { + st.Errors = append(st.Errors, err) + return st, err + } + + root := filepath.Join(workdir, ds.Path) + files, err := ListMarkdown(root, chunkCfg.IncludeExts, chunkCfg.IgnoreDirs, opt.MaxFiles) + if err != nil { + st.Errors = append(st.Errors, err) + return st, err + } + st.FilesScanned = len(files) + + dsn := cfg.Global.VectorDB.DSN() + conn, err := pgx.Connect(ctx, dsn) + if err != nil { + st.Errors = append(st.Errors, err) + return st, err + } + defer conn.Close(ctx) + + embedder := embed.NewOpenAI(embCfg.BaseURL, embCfg.APIKey, embCfg.Model, embCfg.Dimension) + if err := store.EnsureSchema(ctx, conn, embedder.Dimension(), opt.MigrateDim); err != nil { + st.Errors = append(st.Errors, err) + return st, err + } + + for _, f := range files { + secs, err := ParseMarkdown(f) + if err != nil { + st.Errors = append(st.Errors, err) + continue + } + chunks, err := BuildChunks(secs, chunkCfg) + if err != nil { + st.Errors = append(st.Errors, err) + continue + } + if len(chunks) == 0 { + continue + } + st.ChunksBuilt += len(chunks) + texts := make([]string, len(chunks)) + rows := make([]store.DocRow, len(chunks)) + for i, ch := range chunks { + texts[i] = ch.Text + rows[i] = store.DocRow{ + Repo: ds.Repo, + Path: strings.TrimPrefix(f, workdir+"/"), + ChunkID: ch.ChunkID, + Content: ch.Text, + Metadata: ch.Meta, + ContentSHA: ch.SHA256, + } + } + vecs, tokens, err := embedder.Embed(ctx, texts) + if err != nil { + st.Errors = append(st.Errors, err) + continue + } + st.EmbeddingsCreated += len(vecs) + st.TokensEstimated += tokens + for i := range rows { + rows[i].Embedding = vecs[i] + } + if opt.DryRun { + continue + } + n, err := store.UpsertDocuments(ctx, conn, rows) + if err != nil { + st.Errors = append(st.Errors, err) + continue + } + st.RowsUpserted += n + } + + st.Elapsed = time.Since(start) + return st, nil } diff --git a/server/rag/ingest/markdown.go b/server/rag/ingest/markdown.go new file mode 100644 index 0000000..9e73ac3 --- /dev/null +++ b/server/rag/ingest/markdown.go @@ -0,0 +1,71 @@ +package ingest + +import ( + "bytes" + "os" + "strings" + + "github.com/yuin/goldmark" + "github.com/yuin/goldmark/ast" + "github.com/yuin/goldmark/text" +) + +// Section represents a portion of a markdown document grouped by heading. +type Section struct { + Heading string + Text string +} + +// ParseMarkdown parses a markdown file into sections. It normalizes whitespace +// and preserves code fences. +func ParseMarkdown(path string) ([]Section, error) { + b, err := os.ReadFile(path) + if err != nil { + return nil, err + } + md := goldmark.New() + doc := md.Parser().Parse(text.NewReader(b)) + + var secs []Section + var cur *Section + var buf bytes.Buffer + + ast.Walk(doc, func(n ast.Node, entering bool) (ast.WalkStatus, error) { + switch node := n.(type) { + case *ast.Heading: + if entering { + if cur != nil { + cur.Text = strings.TrimSpace(buf.String()) + secs = append(secs, *cur) + buf.Reset() + } + cur = &Section{Heading: string(node.Text(b))} + } + return ast.WalkContinue, nil + case *ast.CodeBlock: + if entering { + buf.WriteString("```\n") + for i := 0; i < node.Lines().Len(); i++ { + line := node.Lines().At(i) + buf.Write(line.Value(b)) + } + buf.WriteString("\n```\n") + } + return ast.WalkSkipChildren, nil + case *ast.Text: + if entering { + buf.Write(node.Segment.Value(b)) + if node.HardLineBreak() || node.SoftLineBreak() { + buf.WriteByte('\n') + } + } + } + return ast.WalkContinue, nil + }) + + if cur != nil { + cur.Text = strings.TrimSpace(buf.String()) + secs = append(secs, *cur) + } + return secs, nil +} diff --git a/server/rag/ingest/walk.go b/server/rag/ingest/walk.go new file mode 100644 index 0000000..c075518 --- /dev/null +++ b/server/rag/ingest/walk.go @@ -0,0 +1,46 @@ +package ingest + +import ( + "io" + "os" + "path/filepath" + "strings" +) + +// ListMarkdown walks root and returns markdown files respecting include and ignore lists. +// If maxFiles > 0 the result is limited to at most that many paths. +func ListMarkdown(root string, includeExts, ignoreDirs []string, maxFiles int) ([]string, error) { + var files []string + include := make(map[string]struct{}) + for _, e := range includeExts { + include[strings.ToLower(e)] = struct{}{} + } + ignores := make(map[string]struct{}) + for _, d := range ignoreDirs { + ignores[d] = struct{}{} + } + + err := filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error { + if err != nil { + return nil + } + if d.IsDir() { + if _, ok := ignores[d.Name()]; ok { + return filepath.SkipDir + } + return nil + } + if maxFiles > 0 && len(files) >= maxFiles { + return io.EOF + } + ext := strings.ToLower(filepath.Ext(path)) + if _, ok := include[ext]; ok { + files = append(files, path) + } + return nil + }) + if err != nil && err != io.EOF { + return nil, err + } + return files, nil +} diff --git a/server/rag/rag.go b/server/rag/rag.go index 6038278..aafd584 100644 --- a/server/rag/rag.go +++ b/server/rag/rag.go @@ -1,107 +1,4 @@ package rag -import ( - "context" - "time" - - "xcontrol/server/rag/config" - "xcontrol/server/rag/embed" - "xcontrol/server/rag/ingest" - "xcontrol/server/rag/store" - rsync "xcontrol/server/rag/sync" -) - -// Service provides high level RAG operations for syncing data sources -// and querying the vector store. -type Service struct { - cfg *config.Config - st *store.Store - emb embed.Embedder -} - -// New creates a new Service using the provided configuration, -// storage and embedder. Any of the arguments may be nil if the -// corresponding feature is not required by the caller. -func New(cfg *config.Config, st *store.Store, emb embed.Embedder) *Service { - return &Service{cfg: cfg, st: st, emb: emb} -} - -// Sync clones or updates configured repositories, ingests markdown files -// and upserts their embeddings into the store. -func (s *Service) Sync(ctx context.Context) error { - if s == nil || s.cfg == nil || s.st == nil || s.emb == nil { - return nil - } - for _, repo := range s.cfg.Repos { - if err := s.syncRepo(ctx, repo, true); err != nil { - return err - } - } - return nil -} - -// syncRepo pulls markdown files for a single repo and ingests them. -// If force is false, ingestion is skipped when the repo has no changes. -func (s *Service) syncRepo(ctx context.Context, repo config.Repo, force bool) error { - if s == nil || s.st == nil || s.emb == nil { - return nil - } - files, changed, err := rsync.Repo(repo) - if err != nil { - return err - } - if !force && !changed { - return nil - } - for _, f := range files { - docs, err := ingest.File(repo.URL, f) - if err != nil { - continue - } - for i := range docs { - vec, err := s.emb.Embed(ctx, docs[i].Content) - if err != nil { - continue - } - docs[i].Embedding = vec - } - if err := s.st.Upsert(ctx, docs); err != nil { - return err - } - } - return nil -} - -// Watch monitors configured repositories and triggers sync on updates. -func (s *Service) Watch(ctx context.Context) { - if s == nil || s.cfg == nil || s.st == nil || s.emb == nil { - return - } - for _, repo := range s.cfg.Repos { - go func(r config.Repo) { - ticker := time.NewTicker(time.Minute) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - _ = s.syncRepo(ctx, r, false) - } - } - }(repo) - } -} - -// Query embeds the question and searches the store for similar documents. -// If the service is not fully configured, Query returns nil without error. -func (s *Service) Query(ctx context.Context, question string, limit int) ([]store.Document, error) { - if s == nil || s.st == nil || s.emb == nil { - return nil, nil - } - vec, err := s.emb.Embed(ctx, question) - if err != nil { - return nil, err - } - return s.st.Search(ctx, vec, limit) -} +// Package rag provides RAG related helpers. The main ingestion pipeline lives +// under the ingest subpackage. diff --git a/server/rag/rag_legacy.go b/server/rag/rag_legacy.go new file mode 100644 index 0000000..1d47063 --- /dev/null +++ b/server/rag/rag_legacy.go @@ -0,0 +1,109 @@ +//go:build legacy + +package rag + +import ( + "context" + "time" + + "xcontrol/server/rag/config" + "xcontrol/server/rag/embed" + "xcontrol/server/rag/ingest" + "xcontrol/server/rag/store" + rsync "xcontrol/server/rag/sync" +) + +// Service provides high level RAG operations for syncing data sources +// and querying the vector store. +type Service struct { + cfg *config.Config + st *store.Store + emb embed.Embedder +} + +// New creates a new Service using the provided configuration, +// storage and embedder. Any of the arguments may be nil if the +// corresponding feature is not required by the caller. +func New(cfg *config.Config, st *store.Store, emb embed.Embedder) *Service { + return &Service{cfg: cfg, st: st, emb: emb} +} + +// Sync clones or updates configured repositories, ingests markdown files +// and upserts their embeddings into the store. +func (s *Service) Sync(ctx context.Context) error { + if s == nil || s.cfg == nil || s.st == nil || s.emb == nil { + return nil + } + for _, repo := range s.cfg.Repos { + if err := s.syncRepo(ctx, repo, true); err != nil { + return err + } + } + return nil +} + +// syncRepo pulls markdown files for a single repo and ingests them. +// If force is false, ingestion is skipped when the repo has no changes. +func (s *Service) syncRepo(ctx context.Context, repo config.Repo, force bool) error { + if s == nil || s.st == nil || s.emb == nil { + return nil + } + files, changed, err := rsync.Repo(repo) + if err != nil { + return err + } + if !force && !changed { + return nil + } + for _, f := range files { + docs, err := ingest.File(repo.URL, f) + if err != nil { + continue + } + for i := range docs { + vec, err := s.emb.Embed(ctx, docs[i].Content) + if err != nil { + continue + } + docs[i].Embedding = vec + } + if err := s.st.Upsert(ctx, docs); err != nil { + return err + } + } + return nil +} + +// Watch monitors configured repositories and triggers sync on updates. +func (s *Service) Watch(ctx context.Context) { + if s == nil || s.cfg == nil || s.st == nil || s.emb == nil { + return + } + for _, repo := range s.cfg.Repos { + go func(r config.Repo) { + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + _ = s.syncRepo(ctx, r, false) + } + } + }(repo) + } +} + +// Query embeds the question and searches the store for similar documents. +// If the service is not fully configured, Query returns nil without error. +func (s *Service) Query(ctx context.Context, question string, limit int) ([]store.Document, error) { + if s == nil || s.st == nil || s.emb == nil { + return nil, nil + } + vec, err := s.emb.Embed(ctx, question) + if err != nil { + return nil, err + } + return s.st.Search(ctx, vec, limit) +} diff --git a/server/rag/store/store.go b/server/rag/store/store.go index 01dcb28..e8380a9 100644 --- a/server/rag/store/store.go +++ b/server/rag/store/store.go @@ -3,77 +3,91 @@ package store import ( "context" "encoding/json" + "fmt" - "github.com/jackc/pgx/v5/pgxpool" + "github.com/jackc/pgx/v5" + pgvector "github.com/pgvector/pgvector-go" ) -// Document represents a chunk stored in Postgres. -type Document struct { - ID int64 - Repo string - Path string - ChunkID int - Content string - Embedding []float32 - Metadata map[string]any +// DocRow represents a row to be stored in the documents table. +type DocRow struct { + Repo string + Path string + ChunkID int + Content string + Embedding []float32 + Metadata map[string]any + ContentSHA string } -// Store wraps a pgx pool for vector operations. -type Store struct { - pool *pgxpool.Pool -} - -// New creates a new Store connected using dsn. -func New(ctx context.Context, dsn string) (*Store, error) { - p, err := pgxpool.New(ctx, dsn) - if err != nil { - return nil, err +// EnsureSchema creates the documents table and indexes if they do not exist. It +// also validates the embedding dimension. When migrate is true and a dimension +// mismatch is detected, it attempts to alter the column type. +func EnsureSchema(ctx context.Context, conn *pgx.Conn, dim int, migrate bool) error { + // ensure table + create := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS documents ( + id BIGSERIAL PRIMARY KEY, + repo TEXT, + path TEXT, + chunk_id INT, + content TEXT, + embedding VECTOR(%d), + metadata JSONB, + content_sha TEXT NOT NULL, + updated_at TIMESTAMPTZ DEFAULT now(), + UNIQUE(repo,path,chunk_id) + )`, dim) + if _, err := conn.Exec(ctx, create); err != nil { + return err } - return &Store{pool: p}, nil -} - -// Upsert writes documents and their embeddings to the database. -func (s *Store) Upsert(ctx context.Context, docs []Document) error { - for _, d := range docs { - meta, _ := json.Marshal(d.Metadata) - _, err := s.pool.Exec(ctx, - `INSERT INTO documents (repo,path,chunk_id,content,embedding,metadata) - VALUES ($1,$2,$3,$4,$5,$6) - ON CONFLICT (repo,path,chunk_id) DO UPDATE - SET content=EXCLUDED.content, - embedding=EXCLUDED.embedding, - metadata=EXCLUDED.metadata`, - d.Repo, d.Path, d.ChunkID, d.Content, d.Embedding, meta, - ) - if err != nil { + // check dimension + var curDim int + err := conn.QueryRow(ctx, `SELECT atttypmod-4 FROM pg_attribute a JOIN pg_type t ON a.atttypid=t.oid WHERE a.attrelid='documents'::regclass AND a.attname='embedding'`).Scan(&curDim) + if err == nil && curDim != dim { + if !migrate { + return fmt.Errorf("embedding dimension %d != %d", curDim, dim) + } + if _, err := conn.Exec(ctx, `DROP INDEX IF EXISTS documents_embedding_idx`); err != nil { return err } + if _, err := conn.Exec(ctx, fmt.Sprintf(`ALTER TABLE documents ALTER COLUMN embedding TYPE VECTOR(%d)`, dim)); err != nil { + return err + } + } + // index + if _, err := conn.Exec(ctx, `CREATE INDEX IF NOT EXISTS documents_embedding_idx ON documents USING hnsw (embedding vector_cosine_ops)`); err != nil { + return err } return nil } -// Search returns top similar documents ordered by cosine distance. -func (s *Store) Search(ctx context.Context, vec []float32, limit int) ([]Document, error) { - rows, err := s.pool.Query(ctx, - `SELECT repo,path,chunk_id,content,metadata - FROM documents - ORDER BY embedding <-> $1 - LIMIT $2`, vec, limit, - ) - if err != nil { - return nil, err +// UpsertDocuments upserts rows and returns affected row count. +func UpsertDocuments(ctx context.Context, conn *pgx.Conn, rows []DocRow) (int, error) { + if len(rows) == 0 { + return 0, nil } - defer rows.Close() - - var res []Document - for rows.Next() { - var d Document - var meta []byte - if err := rows.Scan(&d.Repo, &d.Path, &d.ChunkID, &d.Content, &meta); err != nil { - return nil, err + batch := &pgx.Batch{} + for _, r := range rows { + meta, _ := json.Marshal(r.Metadata) + batch.Queue(`INSERT INTO documents (repo,path,chunk_id,content,embedding,metadata,content_sha) + VALUES ($1,$2,$3,$4,$5,$6,$7) + ON CONFLICT (repo,path,chunk_id) DO UPDATE + SET content=EXCLUDED.content, + embedding=EXCLUDED.embedding, + metadata=EXCLUDED.metadata, + content_sha=EXCLUDED.content_sha + WHERE documents.content_sha<>EXCLUDED.content_sha`, + r.Repo, r.Path, r.ChunkID, r.Content, pgvector.NewVector(r.Embedding), meta, r.ContentSHA) + } + br := conn.SendBatch(ctx, batch) + count := 0 + for range rows { + ct, err := br.Exec() + if err != nil { + br.Close() + return count, err } - json.Unmarshal(meta, &d.Metadata) - res = append(res, d) + count += int(ct.RowsAffected()) } - return res, rows.Err() + return count, br.Close() } diff --git a/server/rag/sync/sync.go b/server/rag/sync/sync.go index 1ffb724..a971b77 100644 --- a/server/rag/sync/sync.go +++ b/server/rag/sync/sync.go @@ -1,66 +1,74 @@ package sync import ( - "errors" - "io/fs" - "path/filepath" + "context" + "os" - git "github.com/go-git/go-git/v5" - "github.com/go-git/go-git/v5/plumbing" - "xcontrol/server/rag/config" + git "github.com/go-git/go-git/v5" + "github.com/go-git/go-git/v5/plumbing" + "github.com/go-git/go-git/v5/plumbing/transport/http" ) -// Repo synchronizes the configured repository and returns markdown file paths. -// The returned boolean indicates whether new commits were pulled. -func Repo(c config.Repo) ([]string, bool, error) { - changed := false - if _, err := git.PlainOpen(c.Local); err != nil { - opts := &git.CloneOptions{URL: c.URL} - if c.Branch != "" { - opts.ReferenceName = plumbing.NewBranchReferenceName(c.Branch) - opts.SingleBranch = true - } - if _, err := git.PlainClone(c.Local, false, opts); err != nil { - return nil, false, err - } - changed = true - } else { - r, err := git.PlainOpen(c.Local) - if err != nil { - return nil, false, err - } - w, err := r.Worktree() - if err != nil { - return nil, false, err - } - pullOpts := &git.PullOptions{RemoteName: "origin"} - if c.Branch != "" { - pullOpts.ReferenceName = plumbing.NewBranchReferenceName(c.Branch) - pullOpts.SingleBranch = true - } - if err := w.Pull(pullOpts); err != nil { - if !errors.Is(err, git.NoErrAlreadyUpToDate) { - return nil, false, err - } - } else { - changed = true - } - } - var files []string - for _, p := range c.Paths { - root := filepath.Join(c.Local, p) - filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return nil - } - if d.IsDir() { - return nil - } - if filepath.Ext(path) == ".md" { - files = append(files, path) - } - return nil +// SyncRepo ensures the repository at workdir matches the remote url. It +// performs a shallow clone when the directory does not exist, otherwise a +// fetch and reset. The returned string is the current HEAD commit hash. +func SyncRepo(ctx context.Context, url, workdir string) (string, error) { + if _, err := os.Stat(workdir); os.IsNotExist(err) { + // shallow clone + _, err := git.PlainCloneContext(ctx, workdir, false, &git.CloneOptions{ + URL: url, + Depth: 1, }) + if err != nil { + return "", err + } + } else { + r, err := git.PlainOpen(workdir) + if err != nil { + return "", err + } + // fetch + if err := r.FetchContext(ctx, &git.FetchOptions{Depth: 1}); err != nil && err != git.NoErrAlreadyUpToDate { + return "", err + } + // reset to origin/HEAD + head, err := r.ResolveRevision(plumbing.Revision("origin/HEAD")) + if err != nil { + // fallback to master/main + head, err = r.ResolveRevision(plumbing.Revision("origin/main")) + if err != nil { + head, err = r.ResolveRevision(plumbing.Revision("origin/master")) + if err != nil { + return "", err + } + } + } + w, err := r.Worktree() + if err != nil { + return "", err + } + if err := w.Reset(&git.ResetOptions{Mode: git.HardReset, Commit: *head}); err != nil { + return "", err + } } - return files, changed, nil + + r, err := git.PlainOpen(workdir) + if err != nil { + return "", err + } + ref, err := r.Head() + if err != nil { + return "", err + } + return ref.Hash().String(), nil +} + +// WithAuth returns CloneOptions with basic auth if username/token provided in URL. +// This is a helper for future extension; currently unused. +func WithAuth(url, token string) *git.CloneOptions { + opts := &git.CloneOptions{URL: url, Depth: 1} + if token != "" { + opts.Auth = &http.BasicAuth{Username: "token", Password: token} + } + return opts }