accounts/rag-server/api/rag.go

102 lines
2.5 KiB
Go

package api
import (
"context"
"errors"
"net/http"
"sync"
"github.com/gin-gonic/gin"
"xcontrol/internal/rag"
rconfig "xcontrol/internal/rag/config"
ragembed "xcontrol/internal/rag/embed"
"xcontrol/internal/rag/store"
"xcontrol/rag-server/proxy"
)
// ragService defines methods used by the RAG API. It allows tests to supply a
// mock implementation without touching the real vector database or embedding
// service.
type ragService interface {
Upsert(ctx context.Context, rows []store.DocRow) (int, error)
Query(ctx context.Context, question string, limit int) ([]rag.Document, error)
}
// ragSvc handles RAG document storage and retrieval. It is initialized lazily
// on demand. ragMu guards concurrent initialization attempts.
var (
ragSvc ragService
ragMu sync.Mutex
)
// initRAG attempts to construct a RAG service from server configuration.
func initRAG() ragService {
cfg, err := rconfig.LoadServer()
if err != nil {
return nil
}
proxy.Set(cfg.Proxy)
return rag.New(cfg.ToConfig())
}
// getRAG returns an initialized ragService, creating it if necessary.
func getRAG() ragService {
ragMu.Lock()
defer ragMu.Unlock()
if ragSvc == nil {
ragSvc = initRAG()
}
return ragSvc
}
// registerRAGRoutes wires the /api/rag upsert and query endpoints.
func registerRAGRoutes(r *gin.RouterGroup) {
r.POST("/rag/upsert", func(c *gin.Context) {
svc := getRAG()
if svc == nil {
c.JSON(http.StatusOK, gin.H{"rows": 0})
return
}
var req struct {
Docs []store.DocRow `json:"docs"`
}
if err := c.BindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
n, err := svc.Upsert(c.Request.Context(), req.Docs)
if err != nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"rows": 0, "error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"rows": n})
})
r.POST("/rag/query", func(c *gin.Context) {
var req struct {
Question string `json:"question"`
}
if err := c.BindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
svc := getRAG()
if svc == nil {
c.JSON(http.StatusOK, gin.H{"chunks": nil})
return
}
docs, err := svc.Query(c.Request.Context(), req.Question, 5)
if err != nil {
var httpErr *ragembed.HTTPError
if errors.As(err, &httpErr) {
c.JSON(httpErr.Code, gin.H{"error": httpErr.Error()})
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, gin.H{"chunks": docs})
})
}