Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module github.com/k8shell-io/ssh-proxy
go 1.24.5

require (
github.com/k8shell-io/common v0.29.4
github.com/k8shell-io/common v0.30.10
github.com/nats-io/nats.go v1.47.0
github.com/rs/zerolog v1.34.0
golang.org/x/crypto v0.43.0
Expand All @@ -20,6 +20,7 @@ require (
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.27.0 // indirect
github.com/golang-jwt/jwt/v5 v5.3.1 // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
Expand Down
52 changes: 10 additions & 42 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -31,48 +31,16 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/k8shell-io/common v0.20.20 h1:4Tze71ObbtFT6l45wwhDGw8u35w0gTRzuAvH55bEsJE=
github.com/k8shell-io/common v0.20.20/go.mod h1:E8dsb9ta4v3ne61AJgtRyTTbTkMMmKeCMAcXD+/9+cY=
github.com/k8shell-io/common v0.21.0 h1:EOwaQOFnHQJsHcDLDVEAwNnJJe7uvQCqcOiI6BTu3GE=
github.com/k8shell-io/common v0.21.0/go.mod h1:E8dsb9ta4v3ne61AJgtRyTTbTkMMmKeCMAcXD+/9+cY=
github.com/k8shell-io/common v0.22.4 h1:zii3NIHOldIrwIldtuwWFFqjEV6JYhbtT50na5pDxLk=
github.com/k8shell-io/common v0.22.4/go.mod h1:E8dsb9ta4v3ne61AJgtRyTTbTkMMmKeCMAcXD+/9+cY=
github.com/k8shell-io/common v0.23.0 h1:3v5KjcITfMpdKGVHinHASvV5Y3zAZiBA8y6Br+gusx8=
github.com/k8shell-io/common v0.23.0/go.mod h1:E8dsb9ta4v3ne61AJgtRyTTbTkMMmKeCMAcXD+/9+cY=
github.com/k8shell-io/common v0.23.1 h1:Mk0vlt8/J7w/KhYWNS5/lKhk7Dr14ER5UOkYdu4xsP4=
github.com/k8shell-io/common v0.23.1/go.mod h1:E8dsb9ta4v3ne61AJgtRyTTbTkMMmKeCMAcXD+/9+cY=
github.com/k8shell-io/common v0.23.2 h1:Gvnt6XzHq9tiAeuhCw64uw5RIUmVQZY57U2X55BLgg8=
github.com/k8shell-io/common v0.23.2/go.mod h1:E8dsb9ta4v3ne61AJgtRyTTbTkMMmKeCMAcXD+/9+cY=
github.com/k8shell-io/common v0.23.3 h1:m/1f8x2TfhIqY6hNFBkOthpieCsYtEwac8+QaYarX2c=
github.com/k8shell-io/common v0.23.3/go.mod h1:E8dsb9ta4v3ne61AJgtRyTTbTkMMmKeCMAcXD+/9+cY=
github.com/k8shell-io/common v0.23.4 h1:T4VEI0I8efV/Dz2dXmQMUdJ3wOWPy1R/+m/50IEC68A=
github.com/k8shell-io/common v0.23.4/go.mod h1:E8dsb9ta4v3ne61AJgtRyTTbTkMMmKeCMAcXD+/9+cY=
github.com/k8shell-io/common v0.23.6 h1:t9ltW+n6pMpxp1q51+7OSoRrM9qMpM108aO6WtVeGTk=
github.com/k8shell-io/common v0.23.6/go.mod h1:E8dsb9ta4v3ne61AJgtRyTTbTkMMmKeCMAcXD+/9+cY=
github.com/k8shell-io/common v0.24.0 h1:jnuojNP9uLF2EpxavZPsR7coCkRsczIJUOtxh2uBlNQ=
github.com/k8shell-io/common v0.24.0/go.mod h1:E8dsb9ta4v3ne61AJgtRyTTbTkMMmKeCMAcXD+/9+cY=
github.com/k8shell-io/common v0.24.1 h1:NRXjzT9JlHkfsBU4g0tGv24xJbl4Y3NdswFu2FmT8Q8=
github.com/k8shell-io/common v0.24.1/go.mod h1:E8dsb9ta4v3ne61AJgtRyTTbTkMMmKeCMAcXD+/9+cY=
github.com/k8shell-io/common v0.24.3 h1:/ni3OBj9WXpe4sPLauIrRzJud5UEAwZOUJxfQDD4yu4=
github.com/k8shell-io/common v0.24.3/go.mod h1:E8dsb9ta4v3ne61AJgtRyTTbTkMMmKeCMAcXD+/9+cY=
github.com/k8shell-io/common v0.24.4 h1:clB+fM0K6yZHnSPyGL9kAimPL9g8Blxr0aVwkaEMasg=
github.com/k8shell-io/common v0.24.4/go.mod h1:E8dsb9ta4v3ne61AJgtRyTTbTkMMmKeCMAcXD+/9+cY=
github.com/k8shell-io/common v0.24.5 h1:NwgsWDr4S8gLzXVVgVOHOk44m0eO3uCypmjeph+wBJg=
github.com/k8shell-io/common v0.24.5/go.mod h1:E8dsb9ta4v3ne61AJgtRyTTbTkMMmKeCMAcXD+/9+cY=
github.com/k8shell-io/common v0.24.6 h1:JbdxoOgYbx8D2zRNO5p/EyRZCrNbRtQj6RGetibW9lk=
github.com/k8shell-io/common v0.24.6/go.mod h1:E8dsb9ta4v3ne61AJgtRyTTbTkMMmKeCMAcXD+/9+cY=
github.com/k8shell-io/common v0.25.0 h1:wPVpNT5bjq2teOwWd6J6jIMLP4c5JrDeNCNjXXdIkbg=
github.com/k8shell-io/common v0.25.0/go.mod h1:E8dsb9ta4v3ne61AJgtRyTTbTkMMmKeCMAcXD+/9+cY=
github.com/k8shell-io/common v0.25.2 h1:RuxkyxlDhAXmTdOaI2KupZnS/ckQVAXLSat1L9ZmFmw=
github.com/k8shell-io/common v0.25.2/go.mod h1:E8dsb9ta4v3ne61AJgtRyTTbTkMMmKeCMAcXD+/9+cY=
github.com/k8shell-io/common v0.26.0 h1:IdqXufdCJOPw+HygsmNYSQlDJiMEJIPeGVAus5BIJw4=
github.com/k8shell-io/common v0.26.0/go.mod h1:E8dsb9ta4v3ne61AJgtRyTTbTkMMmKeCMAcXD+/9+cY=
github.com/k8shell-io/common v0.26.1 h1:hJQoq5pHd20pFjHWThnD+VIgPlrv9qXE37P7IH68/m4=
github.com/k8shell-io/common v0.26.1/go.mod h1:E8dsb9ta4v3ne61AJgtRyTTbTkMMmKeCMAcXD+/9+cY=
github.com/k8shell-io/common v0.26.2 h1:pGuamWyw+UoZI8Z5gEltwttfgHJiCHmy+2QSOgPWeL0=
github.com/k8shell-io/common v0.26.2/go.mod h1:E8dsb9ta4v3ne61AJgtRyTTbTkMMmKeCMAcXD+/9+cY=
github.com/k8shell-io/common v0.29.4 h1:patjuhCWs3g/JVNmx7SVa+wgvu46Z5RYAkD7YY6A4tY=
github.com/k8shell-io/common v0.29.4/go.mod h1:E8dsb9ta4v3ne61AJgtRyTTbTkMMmKeCMAcXD+/9+cY=
github.com/k8shell-io/common v0.30.2 h1:eBQi8lH8mDu3x7V8ja2h/ayOiKlxfkFy9PcuZPLhr9c=
github.com/k8shell-io/common v0.30.2/go.mod h1:40c5GkpS7Y0/aOFa37Lq8z/mLUn3k3GV/AHtFJFL28k=
github.com/k8shell-io/common v0.30.3 h1:VJZxIGUMmm5yWNY/o/yw/J53Owf6zOrTeDNXfLqknZw=
github.com/k8shell-io/common v0.30.3/go.mod h1:40c5GkpS7Y0/aOFa37Lq8z/mLUn3k3GV/AHtFJFL28k=
github.com/k8shell-io/common v0.30.8 h1:GeSrJnu4GBkZ7zzlT0ywIWNl9/+TGbAjFNy6LK5CYqM=
github.com/k8shell-io/common v0.30.8/go.mod h1:40c5GkpS7Y0/aOFa37Lq8z/mLUn3k3GV/AHtFJFL28k=
github.com/k8shell-io/common v0.30.9 h1:urpQL8G1ucAgpqq5TPetDBdX9BgqlJ34bRc9urnIrBw=
github.com/k8shell-io/common v0.30.9/go.mod h1:40c5GkpS7Y0/aOFa37Lq8z/mLUn3k3GV/AHtFJFL28k=
github.com/k8shell-io/common v0.30.10 h1:sIL7pjx38YE/KtRB2VXK+BHL7D+NcM88Lzop/xJfBT8=
github.com/k8shell-io/common v0.30.10/go.mod h1:40c5GkpS7Y0/aOFa37Lq8z/mLUn3k3GV/AHtFJFL28k=
github.com/k8shell-io/crypto v0.41.1-ssh-proxy h1:8+q6Ofc2ky23Oc9iNyiq8aeiQBIP+y3+O6zzHqe1f48=
github.com/k8shell-io/crypto v0.41.1-ssh-proxy/go.mod h1:RVZeOJCpqtogniULztSXQESKJCfcI8WCxsS0FagMA8U=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
Expand Down
141 changes: 99 additions & 42 deletions internal/server/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

identityv1 "github.com/k8shell-io/common/pkg/api/gen/go/identity/v1"
"github.com/k8shell-io/common/pkg/authz"
"github.com/k8shell-io/common/pkg/gapi"
"github.com/k8shell-io/common/pkg/models"
"golang.org/x/crypto/ssh"
Expand All @@ -21,13 +22,16 @@ import (

// AllowedAuthsCallback returns the available authentication methods for the user.
func (s *Server) AllowedAuthsCallback(conn ssh.ConnMetadata) ssh.ServerAuthCallbacks {
ctx, cancel := context.WithTimeout(s.ctx, 30*time.Second)
defer cancel()

connInfo, err := s.GetConnInfo(conn)
if err != nil {
s.log.Error().Msgf("Failed to get connection info: %v", err)
return ssh.ServerAuthCallbacks{}
}
s.updateUser(s.ctx, connInfo)
authMethods := s.getAvailableAuthMethods(connInfo)
s.updateUser(ctx, connInfo)
authMethods := s.getAvailableAuthMethods(ctx, connInfo)
if authMethods == nil {
return ssh.ServerAuthCallbacks{}
}
Expand All @@ -45,18 +49,21 @@ func (s *Server) AuthPublicKey(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ss
}
s.updateUser(ctx, connInfo)
if connInfo.user != nil {
if slices.Contains(connInfo.user.Auths, "publickey") {
if s.authPublicKey(connInfo.user, pubKey) {
s.log.Info().Msgf("User %s authenticated with public key", connInfo.user.Username)
return &ssh.Permissions{}, nil
} else {
connInfo.AddFailureInfo("Public key authentication failed", nil)
return nil, fmt.Errorf("public key authentication failed for user %s", connInfo.user.Username)
}
} else {
connInfo.AddFailureInfo("Public key authentication not available", nil)
methods, err := s.resolveAuthMethods(ctx, connInfo)
if err != nil {
connInfo.AddFailureInfo("Failed to resolve authentication methods", err)
return nil, fmt.Errorf("public key authentication not available for user %s", connInfo.user.Username)
}
if !slices.Contains(methods, authz.UserAuthMethodPublicKey) {
connInfo.AddFailureInfo("Public key authentication not permitted by policy", nil)
return nil, fmt.Errorf("public key authentication not available for user %s", connInfo.user.Username)
}
if s.authPublicKey(connInfo.user, pubKey) {
s.log.Info().Msgf("User %s authenticated with public key", connInfo.user.Username)
return &ssh.Permissions{}, nil
}
connInfo.AddFailureInfo("Public key authentication failed", nil)
return nil, fmt.Errorf("public key authentication failed for user %s", connInfo.user.Username)
}

if connInfo.GetOnboardCap() != nil && connInfo.GetOnboardCap().CanOnboard {
Expand All @@ -82,18 +89,21 @@ func (s *Server) AuthPassword(conn ssh.ConnMetadata, password []byte) (*ssh.Perm
}
s.updateUser(ctx, connInfo)
if connInfo.user != nil {
if slices.Contains(connInfo.user.Auths, "password") {
if s.authPassword(connInfo.user) {
s.log.Info().Msgf("User %s authenticated with password", connInfo.user.Username)
return &ssh.Permissions{}, nil
} else {
connInfo.AddFailureInfo("Password authentication failed", nil)
return nil, fmt.Errorf("password authentication failed for user %s", connInfo.user.Username)
}
} else {
connInfo.AddFailureInfo("Password authentication not available", nil)
methods, err := s.resolveAuthMethods(ctx, connInfo)
if err != nil {
connInfo.AddFailureInfo("Failed to resolve authentication methods", err)
return nil, fmt.Errorf("password authentication not available for user %s", connInfo.user.Username)
}
if !slices.Contains(methods, authz.UserAuthMethodPassword) {
connInfo.AddFailureInfo("Password authentication not permitted by policy", nil)
return nil, fmt.Errorf("password authentication not available for user %s", connInfo.user.Username)
}
if s.authPassword(connInfo.user) {
s.log.Info().Msgf("User %s authenticated with password", connInfo.user.Username)
return &ssh.Permissions{}, nil
}
connInfo.AddFailureInfo("Password authentication failed", nil)
return nil, fmt.Errorf("password authentication failed for user %s", connInfo.user.Username)
}

if connInfo.GetOnboardCap() != nil && connInfo.GetOnboardCap().CanOnboard {
Expand Down Expand Up @@ -188,7 +198,7 @@ func (s *Server) checkAuthInteractiveResponse(ctx context.Context,
s.log.Error().Msgf("Failed to complete device flow for user %s: %v", onboardInfo.Username, err)
}
s.log.Info().Msgf("Onboarding completed for user %s", onboardInfo.Username)
return nil, s.getAvailableAuthMethods(auth)
return nil, s.getAvailableAuthMethods(ctx, auth)
}

// AuthPublicKey handles public key authentication via the identity service.
Expand Down Expand Up @@ -260,7 +270,7 @@ func (s *Server) updateUser(ctx context.Context, connInfo *Connection) {

// getAvailableAuthMethods returns the available authentication methods for the user.
// It returns callbacks for the SSH server authentication process.
func (s *Server) getAvailableAuthMethods(connInfo *Connection) *ssh.PartialSuccessError {
func (s *Server) getAvailableAuthMethods(ctx context.Context, connInfo *Connection) *ssh.PartialSuccessError {
if connInfo.user == nil {
onboardCap := connInfo.GetOnboardCap()
if onboardCap != nil && onboardCap.CanOnboard {
Expand All @@ -275,28 +285,75 @@ func (s *Server) getAvailableAuthMethods(connInfo *Connection) *ssh.PartialSucce
return nil
}

methods, err := s.resolveAuthMethods(ctx, connInfo)
if err != nil {
s.log.Error().Msgf("Failed to resolve authentication methods for user %s: %v", connInfo.user.Username, err)
connInfo.AddFailureInfo("Failed to resolve authentication methods", err)
return nil
}

callbacks := ssh.ServerAuthCallbacks{}
for _, authMethod := range connInfo.user.Auths {
switch string(authMethod) {
case "publickey":
s.log.Debug().Msgf("Enabling public key authentication for user %s", connInfo.user.Username)
callbacks.PublicKeyCallback = s.AuthPublicKey
case "password":
s.log.Debug().Msgf("Enabling password authentication for user %s", connInfo.user.Username)
callbacks.PasswordCallback = s.AuthPassword
}
if slices.Contains(methods, authz.UserAuthMethodPublicKey) {
s.log.Debug().Msgf("Enabling public key authentication for user %s", connInfo.user.Username)
callbacks.PublicKeyCallback = s.AuthPublicKey
}
if slices.Contains(methods, authz.UserAuthMethodPassword) {
s.log.Debug().Msgf("Enabling password authentication for user %s", connInfo.user.Username)
callbacks.PasswordCallback = s.AuthPassword
}

if callbacks.PublicKeyCallback != nil ||
callbacks.PasswordCallback != nil ||
callbacks.KeyboardInteractiveCallback != nil {
return &ssh.PartialSuccessError{
Next: callbacks,
}
if callbacks.PublicKeyCallback == nil && callbacks.PasswordCallback == nil {
s.log.Warn().Msgf("No available authentication methods for user %s", connInfo.userStr.Username())
connInfo.AddFailureInfo("No available authentication methods", nil)
return nil
}

return &ssh.PartialSuccessError{
Next: callbacks,
}
}

// resolveAuthMethods returns the SSH authentication methods permitted for the
// connection's user, evaluating the user:auth policy at most once per
// connection (the result is cached on Connection, since it gates both
// advertising in getAvailableAuthMethods and enforcement in
// AuthPublicKey/AuthPassword). When authz is not configured, both methods are
// permitted. When authz is configured but the response carries no
// auth_methods obligation, no methods are permitted, per the user:auth
// contract.
func (s *Server) resolveAuthMethods(ctx context.Context, connInfo *Connection) ([]authz.UserAuthMethod, error) {
if methods, ok := connInfo.GetAuthMethods(); ok {
return methods, nil
}

s.log.Warn().Msgf("No available authentication methods for user %s", connInfo.userStr.Username())
connInfo.AddFailureInfo("No available authentication methods", nil)
if s.authzClient == nil {
methods := []authz.UserAuthMethod{authz.UserAuthMethodPublicKey, authz.UserAuthMethodPassword}
connInfo.SetAuthMethods(methods)
return methods, nil
}

return nil
token, err := connInfo.GetUserToken()
if err != nil {
return nil, fmt.Errorf("failed to get user token for authz check: %w", err)
}

req, err := authz.NewUserAuthEvalRequest(connInfo.user.Username).
WithIDP(connInfo.user.Source).
WithOrg(connInfo.user.Organization).
Build()
if err != nil {
return nil, fmt.Errorf("failed to build user:auth request: %w", err)
}

ob, found, err := s.checkUserAuthMethodsAuthz(ctx, token, req)
if err != nil {
return nil, err
}

methods := []authz.UserAuthMethod{}
if found {
methods = ob.Methods
}
connInfo.SetAuthMethods(methods)
return methods, nil
}
27 changes: 27 additions & 0 deletions internal/server/authzcheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,30 @@ func (s *Server) checkSessionAuthz(ctx context.Context, token string, req *authz
ob, found := authz.ParseRecordObligation(resp.GetObligations())
return ob, found, nil
}

// checkUserAuthMethodsAuthz evaluates a user:auth request against the authz
// service and returns the auth_methods obligation naming which SSH
// authentication methods the policy permits for the user. When authz is not
// configured, returns (zero, false, nil) so callers fall back to their own
// default. When authz is configured but the response carries no auth_methods
// obligation, found is false and, per the user:auth contract, the caller must
// offer no authentication methods.
func (s *Server) checkUserAuthMethodsAuthz(ctx context.Context, token string, req *authz.UserAuthEvalRequest) (authz.AuthMethodsObligation, bool, error) {
if s.authzClient == nil {
return authz.AuthMethodsObligation{}, false, nil
}
if err := req.Validate(); err != nil {
return authz.AuthMethodsObligation{}, false, fmt.Errorf("authz: invalid request: %w", err)
}
protoReq := req.ToProto(token)
protoReq.Package = "user"
resp, err := s.authzClient.Evaluate(ctx, protoReq)
if err != nil {
return authz.AuthMethodsObligation{}, false, fmt.Errorf("authz: evaluate user:auth: %w", err)
}
if !resp.GetAllowed() {
return authz.AuthMethodsObligation{}, false, fmt.Errorf("user:auth denied: %s", resp.GetReason())
}
ob, found := authz.ParseAuthMethodsObligation(resp.GetObligations())
return ob, found, nil
}
32 changes: 30 additions & 2 deletions internal/server/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ type Connection struct {
reportStopCh chan struct{} // channel to signal report goroutine to stop
reportWg sync.WaitGroup // wait group for report goroutine
ptyName string // name of the allocated pseudo-terminal (if any)
authMethodsMu sync.RWMutex // mutex for synchronizing access to authMethods
authMethods []authz.UserAuthMethod // SSH authentication methods permitted by policy (resolved once per connection)
authMethodsSet bool // whether authMethods has been resolved
}

// Session holds information about a user's SSH session
Expand Down Expand Up @@ -320,7 +323,7 @@ func (c *Connection) updateSession(action string) (bool, error) {
Workspace: c.workspaceName,
BytesIn: curIn,
BytesOut: curOut,
Channels: curChannels,
Operations: curChannels,
UpdatedAt: &t,
Blueprint: c.userStr.Blueprint(),
}
Expand Down Expand Up @@ -365,6 +368,25 @@ func (c *Connection) GetOnboardCap() *models.OnboardCapability {
return c.onboardCap
}

// SetAuthMethods caches the SSH authentication methods permitted by policy
// for this connection, so the user:auth policy is evaluated at most once per
// connection even though it gates both advertising (getAvailableAuthMethods)
// and enforcement (AuthPublicKey/AuthPassword).
func (c *Connection) SetAuthMethods(methods []authz.UserAuthMethod) {
c.authMethodsMu.Lock()
defer c.authMethodsMu.Unlock()
c.authMethods = methods
c.authMethodsSet = true
}

// GetAuthMethods retrieves the cached policy-permitted authentication methods.
// The second return value is false when the methods have not been resolved yet.
func (c *Connection) GetAuthMethods() ([]authz.UserAuthMethod, bool) {
c.authMethodsMu.RLock()
defer c.authMethodsMu.RUnlock()
return c.authMethods, c.authMethodsSet
}

// grpcClientMessage returns a clean single-line message from an error,
// collapsing newlines and extra whitespace.
func grpcClientMessage(err error) string {
Expand Down Expand Up @@ -447,7 +469,13 @@ func (c *Connection) Handshake(writer io.Writer, writerOptions *workspace.InfoWr
c.log.Debug().Msgf("Connecting to k8shelld at %s:%d for user %s, version: %s",
status.ServerName, status.Port, c.user.Username, status.AppVersion)

handshake, err := k8shelld.Handshake(c.ctx, "")
userToken, err := c.GetUserToken()
if err != nil {
infoWriter.WriteSystemError("Failed to obtain access token.")
return nil, fmt.Errorf("failed to get user token for user %s: %w", c.user.Username, err)
}

handshake, err := k8shelld.Handshake(c.ctx, userToken)
if err != nil {
msg := grpcClientMessage(err)
if s, ok := grpcstatus.FromError(err); ok && s.Code() == grpccodes.Unavailable {
Expand Down
Loading
Loading