208 lines
5.9 KiB
Go
208 lines
5.9 KiB
Go
package auth
|
|
|
|
import (
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
|
|
"account/internal/store"
|
|
)
|
|
|
|
// TokenPair represents a pair of Public and Access tokens
|
|
type TokenPair struct {
|
|
PublicToken string `json:"public_token"`
|
|
AccessToken string `json:"access_token"`
|
|
RefreshToken string `json:"refresh_token"`
|
|
TokenType string `json:"token_type"`
|
|
ExpiresIn int64 `json:"expires_in"`
|
|
}
|
|
|
|
// Claims represents JWT access token claims
|
|
type Claims struct {
|
|
UserID string `json:"user_id"`
|
|
Email string `json:"email"`
|
|
Roles []string `json:"roles"`
|
|
MFA bool `json:"mfa_verified"`
|
|
jwt.RegisteredClaims
|
|
}
|
|
|
|
// TokenService handles token generation and validation
|
|
type TokenService struct {
|
|
publicToken string
|
|
refreshSecret string
|
|
accessSecret string
|
|
accessExpiry time.Duration
|
|
refreshExpiry time.Duration
|
|
store store.Store
|
|
}
|
|
|
|
// TokenConfig holds configuration for token service
|
|
type TokenConfig struct {
|
|
PublicToken string
|
|
RefreshSecret string
|
|
AccessSecret string
|
|
AccessExpiry time.Duration
|
|
RefreshExpiry time.Duration
|
|
Store store.Store
|
|
}
|
|
|
|
// NewTokenService creates a new TokenService instance
|
|
func NewTokenService(config TokenConfig) *TokenService {
|
|
return &TokenService{
|
|
publicToken: config.PublicToken,
|
|
refreshSecret: config.RefreshSecret,
|
|
accessSecret: config.AccessSecret,
|
|
accessExpiry: config.AccessExpiry,
|
|
refreshExpiry: config.RefreshExpiry,
|
|
store: config.Store,
|
|
}
|
|
}
|
|
|
|
// SetStore sets the store for the token service.
|
|
func (s *TokenService) SetStore(st store.Store) {
|
|
s.store = st
|
|
}
|
|
|
|
// ValidatePublicToken validates the public token
|
|
func (s *TokenService) ValidatePublicToken(publicToken string) bool {
|
|
return publicToken == s.publicToken
|
|
}
|
|
|
|
// GeneratePublicToken returns the configured public token.
|
|
func (s *TokenService) GeneratePublicToken(userID, email string, roles []string) string {
|
|
return s.publicToken
|
|
}
|
|
|
|
// GenerateTokenPair generates a new token pair
|
|
func (s *TokenService) GenerateTokenPair(userID, email string, roles []string) (*TokenPair, error) {
|
|
// Generate refresh token (JWT)
|
|
refreshClaims := jwt.RegisteredClaims{
|
|
Subject: userID,
|
|
Audience: []string{"xcontrol-refresh"},
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(s.refreshExpiry)),
|
|
NotBefore: jwt.NewNumericDate(time.Now()),
|
|
Issuer: "xcontrol-account",
|
|
}
|
|
|
|
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims)
|
|
refreshTokenString, err := refreshToken.SignedString([]byte(s.refreshSecret))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to sign refresh token: %w", err)
|
|
}
|
|
|
|
// Generate access token (JWT)
|
|
claims := Claims{
|
|
UserID: userID,
|
|
Email: email,
|
|
Roles: roles,
|
|
MFA: true, // Assume MFA is verified for now
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Subject: userID,
|
|
Audience: []string{"xcontrol-access"},
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(s.accessExpiry)),
|
|
NotBefore: jwt.NewNumericDate(time.Now()),
|
|
Issuer: "xcontrol-account",
|
|
},
|
|
}
|
|
|
|
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
accessTokenString, err := accessToken.SignedString([]byte(s.accessSecret))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to sign access token: %w", err)
|
|
}
|
|
|
|
return &TokenPair{
|
|
PublicToken: s.publicToken,
|
|
AccessToken: accessTokenString,
|
|
RefreshToken: refreshTokenString,
|
|
TokenType: "Bearer",
|
|
ExpiresIn: int64(s.accessExpiry.Seconds()),
|
|
}, nil
|
|
}
|
|
|
|
// ValidateAccessToken validates and parses an access token
|
|
func (s *TokenService) ValidateAccessToken(accessToken string) (*Claims, error) {
|
|
token, err := jwt.ParseWithClaims(accessToken, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return []byte(s.accessSecret), nil
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse access token: %w", err)
|
|
}
|
|
|
|
claims, ok := token.Claims.(*Claims)
|
|
if !ok || !token.Valid {
|
|
return nil, fmt.Errorf("invalid access token")
|
|
}
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
// RefreshAccessToken generates a new access token using refresh token
|
|
func (s *TokenService) RefreshAccessToken(refreshToken string) (string, error) {
|
|
token, err := jwt.ParseWithClaims(refreshToken, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return []byte(s.refreshSecret), nil
|
|
})
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to parse refresh token: %w", err)
|
|
}
|
|
|
|
claims, ok := token.Claims.(*jwt.RegisteredClaims)
|
|
if !ok || !token.Valid {
|
|
return "", fmt.Errorf("invalid refresh token")
|
|
}
|
|
|
|
// Verify issuer and audience
|
|
if claims.Issuer != "xcontrol-account" {
|
|
return "", fmt.Errorf("invalid token issuer")
|
|
}
|
|
|
|
if !contains(claims.Audience, "xcontrol-refresh") {
|
|
return "", fmt.Errorf("invalid token audience")
|
|
}
|
|
|
|
// Generate new access token
|
|
newClaims := Claims{
|
|
UserID: claims.Subject,
|
|
Email: "", // Will be populated from user store
|
|
Roles: []string{"user"},
|
|
MFA: true,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Subject: claims.Subject,
|
|
Audience: []string{"xcontrol-access"},
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(s.accessExpiry)),
|
|
NotBefore: jwt.NewNumericDate(time.Now()),
|
|
Issuer: "xcontrol-account",
|
|
},
|
|
}
|
|
|
|
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, newClaims)
|
|
accessTokenString, err := accessToken.SignedString([]byte(s.accessSecret))
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to sign access token: %w", err)
|
|
}
|
|
|
|
return accessTokenString, nil
|
|
}
|
|
|
|
// GetAccessTokenExpiry returns the access token expiry duration
|
|
func (s *TokenService) GetAccessTokenExpiry() time.Duration {
|
|
return s.accessExpiry
|
|
}
|
|
|
|
// Helper function to check if a slice contains a string
|
|
func contains(slice []string, str string) bool {
|
|
for _, s := range slice {
|
|
if s == str {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|