diff --git a/.gitignore b/.gitignore index feb7091..3c94b6f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ talm dist/ +.claude/worktrees/ diff --git a/README.md b/README.md index 750bb33..a7b75a2 100644 --- a/README.md +++ b/README.md @@ -52,11 +52,20 @@ curl -sSL https://github.com/cozystack/talm/raw/refs/heads/main/hack/install.sh ## Getting Started -Create new project +Create new project using the interactive wizard: + +```bash +mkdir newcluster +cd newcluster +talm interactive +``` + +Or use the non-interactive command: + ```bash mkdir newcluster cd newcluster -talm init -p cozystack -N myawesomecluster +talm init --preset cozystack --name myawesomecluster ``` Boot Talos Linux node, let's say it has address `1.2.3.4` diff --git a/go.mod b/go.mod index cb96e78..4db9e0d 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,7 @@ require ( github.com/cosi-project/runtime v1.14.1 github.com/distribution/reference v0.6.0 // indirect github.com/docker/cli v29.4.0+incompatible // indirect - github.com/dustin/go-humanize v1.0.1 // indirect + github.com/dustin/go-humanize v1.0.1 github.com/fatih/color v1.19.0 // indirect github.com/foxboron/go-uefi v0.0.0-20251010190908-d29549a44f29 // indirect github.com/gdamore/tcell/v2 v2.13.8 // indirect @@ -94,6 +94,9 @@ require ( filippo.io/age v1.3.1 github.com/BurntSushi/toml v1.6.0 github.com/Masterminds/sprig/v3 v3.3.0 + github.com/charmbracelet/bubbles v1.0.0 + github.com/charmbracelet/bubbletea v1.3.10 + github.com/charmbracelet/lipgloss v1.1.0 github.com/gobwas/glob v0.2.3 github.com/pkg/errors v0.9.1 github.com/siderolabs/talos v1.12.6 @@ -119,6 +122,7 @@ require ( github.com/adrg/xdg v0.5.3 // indirect github.com/antlr4-go/antlr/v4 v4.13.1 // indirect github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2 // indirect + github.com/atotto/clipboard v0.1.4 // indirect github.com/aws/aws-sdk-go-v2 v1.41.5 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.19.14 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 // indirect @@ -130,11 +134,18 @@ require ( github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 // indirect + github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/brianvoe/gofakeit/v7 v7.14.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/chai2010/gettext-go v1.0.3 // indirect + github.com/charmbracelet/colorprofile v0.4.1 // indirect + github.com/charmbracelet/x/ansi v0.11.6 // indirect + github.com/charmbracelet/x/cellbuf v0.0.15 // indirect + github.com/charmbracelet/x/term v0.2.2 // indirect github.com/cilium/ebpf v0.21.0 // indirect + github.com/clipperhouse/displaywidth v0.9.0 // indirect + github.com/clipperhouse/stringish v0.1.1 // indirect github.com/clipperhouse/uax29/v2 v2.7.0 // indirect github.com/cloudflare/circl v1.6.3 // indirect github.com/containerd/containerd/v2 v2.2.2 // indirect @@ -151,6 +162,7 @@ require ( github.com/docker/docker-credential-helpers v0.9.5 // indirect github.com/emicklei/dot v1.11.0 // indirect github.com/emicklei/go-restful/v3 v3.13.0 // indirect + github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/evanphx/json-patch v5.9.11+incompatible // indirect github.com/evanphx/json-patch/v5 v5.9.11 // indirect github.com/exponent-io/jsonpath v0.0.0-20210407135951-1de76d718b3f // indirect @@ -203,6 +215,7 @@ require ( github.com/lmittmann/tint v1.1.3 // indirect github.com/lucasb-eyer/go-colorful v1.4.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-localereader v0.0.1 // indirect github.com/mattn/go-runewidth v0.0.23 // indirect github.com/mdlayher/socket v0.6.0 // indirect github.com/mitchellh/copystructure v1.2.0 // indirect @@ -215,6 +228,9 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect github.com/monochromegane/go-gitignore v0.0.0-20200626010858-205db1a8cc00 // indirect + github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect + github.com/muesli/cancelreader v0.2.2 // indirect + github.com/muesli/termenv v0.16.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f // indirect github.com/nsf/termbox-go v1.1.1 // indirect @@ -242,6 +258,7 @@ require ( github.com/x448/float16 v0.8.4 // indirect github.com/xiang90/probing v0.0.0-20221125231312-a49e3df8f510 // indirect github.com/xlab/treeprint v1.2.0 // indirect + github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect go.etcd.io/bbolt v1.4.3 // indirect go.etcd.io/etcd/pkg/v3 v3.6.10 // indirect go.etcd.io/etcd/server/v3 v3.6.10 // indirect diff --git a/go.sum b/go.sum index f130b42..b1422db 100644 --- a/go.sum +++ b/go.sum @@ -56,6 +56,8 @@ github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2 h1:7Ip0wMmLHLRJdrloD github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= +github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= +github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV/yY= github.com/aws/aws-sdk-go-v2 v1.41.5/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= github.com/aws/aws-sdk-go-v2/config v1.32.14 h1:opVIRo/ZbbI8OIqSOKmpFaY7IwfFUOCCXBsUpJOwDdI= @@ -86,6 +88,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 h1:p8ogvvLugcR/zLBXTXrTkj0RYBU github.com/aws/aws-sdk-go-v2/service/sts v1.41.10/go.mod h1:60dv0eZJfeVXfbT1tFJinbHrDfSJ2GZl4Q//OSSNAVw= github.com/aws/smithy-go v1.24.3 h1:XgOAaUgx+HhVBoP4v8n6HCQoTRDhoMghKqw4LNHsDNg= github.com/aws/smithy-go v1.24.3/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= +github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= +github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= @@ -100,8 +104,26 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chai2010/gettext-go v1.0.3 h1:9liNh8t+u26xl5ddmWLmsOsdNLwkdRTg5AG+JnTiM80= github.com/chai2010/gettext-go v1.0.3/go.mod h1:y+wnP2cHYaVj19NZhYKAwEMH2CI1gNHeQQ+5AjwawxA= +github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc= +github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E= +github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw= +github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4= +github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk= +github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk= +github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY= +github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30= +github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8= +github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ= +github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI= +github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q= +github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk= +github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI= github.com/cilium/ebpf v0.21.0 h1:4dpx1J/B/1apeTmWBH5BkVLayHTkFrMovVPnHEk+l3k= github.com/cilium/ebpf v0.21.0/go.mod h1:1kHKv6Kvh5a6TePP5vvvoMa1bclRyzUXELSs272fmIQ= +github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA= +github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA= +github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= +github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJv2v7Vk= github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM= github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8= @@ -153,6 +175,8 @@ github.com/emicklei/dot v1.11.0 h1:zsrhCuFHAJge/aZIC4N4LdHy5tqYu4tWEaUzIwdYj4Y= github.com/emicklei/dot v1.11.0/go.mod h1:DeV7GvQtIw4h2u73RKBkkFdvVAz0D9fzeJrgPW6gy/s= github.com/emicklei/go-restful/v3 v3.13.0 h1:C4Bl2xDndpU6nJ4bc1jXd+uTmYPVUwkD6bFY/oTyCes= github.com/emicklei/go-restful/v3 v3.13.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/evanphx/json-patch v5.9.11+incompatible h1:ixHHqfcGvxhWkniF1tWxBHA0yb4Z+d1UQi45df52xW8= github.com/evanphx/json-patch v5.9.11+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/evanphx/json-patch/v5 v5.9.11 h1:/8HVnzMq13/3x9TPvjG08wUGqBTmZBsCWzjTM0wiaDU= @@ -313,6 +337,8 @@ github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHP github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.21 h1:xYae+lCNBP7QuW4PUnNG61ffM4hVIfm+zUzDuSzYLGs= github.com/mattn/go-isatty v0.0.21/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4= +github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= +github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw= @@ -348,6 +374,12 @@ github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee h1:W5t00kpgFd github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/monochromegane/go-gitignore v0.0.0-20200626010858-205db1a8cc00 h1:n6/2gBQ3RWajuToeY6ZtZTIKv2v7ThUy5KKusIT0yc0= github.com/monochromegane/go-gitignore v0.0.0-20200626010858-205db1a8cc00/go.mod h1:Pm3mSP3c5uWn86xMLZ5Sa7JB9GsEZySvHYXCTK4E9q4= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= +github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= +github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= +github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= +github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f h1:y5//uYreIhSUg3J1GEMiLbxo1LJaP8RfCpH6pymGZus= @@ -478,6 +510,8 @@ github.com/xiang90/probing v0.0.0-20221125231312-a49e3df8f510 h1:S2dVYn90KE98chq github.com/xiang90/probing v0.0.0-20221125231312-a49e3df8f510/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/xlab/treeprint v1.2.0 h1:HzHnuAF1plUN2zGlAFHbSQP2qJ0ZAD3XF5XD7OesXRQ= github.com/xlab/treeprint v1.2.0/go.mod h1:gj5Gd3gPdKtR1ikdDK6fnFLdmIS0X30kTTuNd/WEJu0= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= @@ -567,6 +601,7 @@ golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220408201424-a24fb2fb8a0f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/pkg/commands/init.go b/pkg/commands/init.go index b3b927d..5504f22 100644 --- a/pkg/commands/init.go +++ b/pkg/commands/init.go @@ -17,6 +17,7 @@ package commands import ( "bufio" "fmt" + "net/url" "os" "path/filepath" "slices" @@ -25,6 +26,7 @@ import ( "github.com/cozystack/talm/pkg/age" "github.com/cozystack/talm/pkg/generated" + "github.com/cozystack/talm/pkg/wizard" "github.com/spf13/cobra" "gopkg.in/yaml.v3" @@ -92,50 +94,9 @@ var initCmd = &cobra.Command{ return nil }, RunE: func(cmd *cobra.Command, args []string) error { - var ( - secretsBundle *secrets.Bundle - versionContract *config.VersionContract - err error - ) - if initCmdFlags.update { return updateTalmLibraryChart() } - if initCmdFlags.talosVersion != "" { - versionContract, err = config.ParseContractFromVersion(initCmdFlags.talosVersion) - if err != nil { - return fmt.Errorf("invalid talos-version: %w", err) - } - } - - secretsBundle, err = secrets.NewBundle(secrets.NewFixedClock(time.Now()), - versionContract, - ) - if err != nil { - return fmt.Errorf("failed to create secrets bundle: %w", err) - } - var genOptions []generate.Option //nolint:prealloc - // Validate preset only if not using --encrypt or --decrypt - if !initCmdFlags.encrypt && !initCmdFlags.decrypt { - availablePresets, err := generated.AvailablePresets() - if err != nil { - return fmt.Errorf("failed to get available presets: %w", err) - } - if !isValidPreset(initCmdFlags.preset, availablePresets) { - return fmt.Errorf("invalid preset: %s. Valid presets are: %v", initCmdFlags.preset, availablePresets) - } - } - if initCmdFlags.talosVersion != "" { - var versionContract *config.VersionContract - - versionContract, err = config.ParseContractFromVersion(initCmdFlags.talosVersion) - if err != nil { - return fmt.Errorf("invalid talos-version: %w", err) - } - - genOptions = append(genOptions, generate.WithVersionContract(versionContract)) - } - genOptions = append(genOptions, generate.WithSecretsBundle(secretsBundle)) // Handle age encryption logic secretsFile := filepath.Join(Config.RootDir, "secrets.yaml") @@ -282,74 +243,45 @@ var initCmd = &cobra.Command{ return nil } - // If encrypted file exists, decrypt it + // Decrypt existing encrypted files before generation if encryptedSecretsFileExists && !secretsFileExists { if err := age.DecryptSecretsFile(Config.RootDir); err != nil { return fmt.Errorf("failed to decrypt secrets: %w", err) } } - // Write secrets.yaml only if it doesn't exist - if !secretsFileExists { - if err = writeSecretsBundleToFile(secretsBundle); err != nil { - return err - } - secretsFileExists = true // Update flag after creation + // Core project generation (shared with interactive wizard) + kubeconfigName := "kubeconfig" + if Config.GlobalOptions.Kubeconfig != "" { + kubeconfigName = filepath.Base(Config.GlobalOptions.Kubeconfig) + } + if err := GenerateProject(GenerateOptions{ + RootDir: Config.RootDir, + Preset: initCmdFlags.preset, + ClusterName: initCmdFlags.name, + TalosVersion: initCmdFlags.talosVersion, + Version: Config.InitOptions.Version, + Force: initCmdFlags.force, + KubeconfigName: kubeconfigName, + }); err != nil { + return err } - // If secrets.yaml exists but encrypted file doesn't, encrypt it + // Post-generation encryption + secretsFileExists = fileExists(secretsFile) if secretsFileExists && !encryptedSecretsFileExists { - // Generate key if it doesn't exist if !keyFileExists { _, keyCreated, err := age.GenerateKey(Config.RootDir) if err != nil { return fmt.Errorf("failed to generate key: %w", err) } - keyFileExists = true // Update flag after creation keyWasCreated = keyCreated } - - // Encrypt secrets if err := age.EncryptSecretsFile(Config.RootDir); err != nil { return fmt.Errorf("failed to encrypt secrets: %w", err) } } - clusterName := initCmdFlags.name - - // Handle talosconfig encryption logic - talosconfigFile := filepath.Join(Config.RootDir, "talosconfig") - encryptedTalosconfigFile := filepath.Join(Config.RootDir, "talosconfig.encrypted") - talosconfigFileExists := fileExists(talosconfigFile) - encryptedTalosconfigFileExists := fileExists(encryptedTalosconfigFile) - - // If encrypted file exists, decrypt it (don't require key - will generate if needed) - if encryptedTalosconfigFileExists && !talosconfigFileExists { - if _, err := handleTalosconfigEncryption(false); err != nil { - return err - } - talosconfigFileExists = fileExists(talosconfigFile) - } - - // Generate talosconfig only if it doesn't exist - if !talosconfigFileExists { - configBundle, err := gen.GenerateConfigBundle(genOptions, clusterName, "https://192.168.0.1:6443", "", []string{}, []string{}, []string{}) - if err != nil { - return err - } - configBundle.TalosConfig().Contexts[clusterName].Endpoints = []string{"127.0.0.1"} - - data, err := yaml.Marshal(configBundle.TalosConfig()) - if err != nil { - return fmt.Errorf("failed to marshal config: %+v", err) - } - - if err = writeToDestination(data, talosconfigFile, 0o600); err != nil { - return err - } - } - - // Encrypt talosconfig if needed talosKeyCreated, err := handleTalosconfigEncryption(false) if err != nil { return err @@ -358,7 +290,6 @@ var initCmd = &cobra.Command{ keyWasCreated = true } - // Handle kubeconfig encryption logic (check if kubeconfig exists from Chart.yaml) kubeconfigPath := Config.GlobalOptions.Kubeconfig if kubeconfigPath == "" { kubeconfigPath = "kubeconfig" @@ -368,7 +299,6 @@ var initCmd = &cobra.Command{ kubeconfigFileExists := fileExists(kubeconfigFile) encryptedKubeconfigFileExists := fileExists(encryptedKubeconfigFile) - // If encrypted file exists, decrypt it if encryptedKubeconfigFileExists && !kubeconfigFileExists { if err := age.DecryptYAMLFile(Config.RootDir, kubeconfigPath+".encrypted", kubeconfigPath); err != nil { return fmt.Errorf("failed to decrypt kubeconfig: %w", err) @@ -376,9 +306,9 @@ var initCmd = &cobra.Command{ kubeconfigFileExists = true } - // If kubeconfig exists but encrypted file doesn't, encrypt it if kubeconfigFileExists && !encryptedKubeconfigFileExists { - // Ensure key exists + // Re-check key existence (may have been created by talosconfig encryption) + keyFileExists = fileExists(keyFile) if !keyFileExists { _, keyCreated, err := age.GenerateKey(Config.RootDir) if err != nil { @@ -386,58 +316,11 @@ var initCmd = &cobra.Command{ } keyWasCreated = keyCreated } - - // Encrypt kubeconfig if err := age.EncryptYAMLFile(Config.RootDir, kubeconfigPath, kubeconfigPath+".encrypted"); err != nil { return fmt.Errorf("failed to encrypt kubeconfig: %w", err) } } - // Create or update .gitignore file - if err = writeGitignoreFile(); err != nil { - return err - } - - nodesDir := filepath.Join(Config.RootDir, "nodes") - if err := os.MkdirAll(nodesDir, os.ModePerm); err != nil { - return fmt.Errorf("failed to create nodes directory: %w", err) - } - - presetFiles, err := generated.PresetFiles() - if err != nil { - return fmt.Errorf("failed to get preset files: %w", err) - } - - for path, content := range presetFiles { - parts := strings.SplitN(path, "/", 2) - chartName := parts[0] - // Write preset files - if chartName == initCmdFlags.preset { - file := filepath.Join(Config.RootDir, filepath.Join(parts[1:]...)) - if parts[len(parts)-1] == "Chart.yaml" { - err = writeToDestination(fmt.Appendf(nil, content, clusterName, Config.InitOptions.Version), file, 0o644) - } else { - err = writeToDestination([]byte(content), file, 0o644) - } - if err != nil { - return err - } - } - // Write library chart - if chartName == "talm" { - file := filepath.Join(Config.RootDir, filepath.Join("charts", path)) - if parts[len(parts)-1] == "Chart.yaml" { - err = writeToDestination(fmt.Appendf(nil, content, "talm", Config.InitOptions.Version), file, 0o644) - } else { - err = writeToDestination([]byte(content), file, 0o644) - } - if err != nil { - return err - } - } - } - - // Print warning about secrets and key backup (only once, at the end, if key was created) if keyWasCreated { printSecretsWarning() } @@ -447,20 +330,6 @@ var initCmd = &cobra.Command{ }, } -func writeSecretsBundleToFile(bundle *secrets.Bundle) error { - bundleBytes, err := yaml.Marshal(bundle) - if err != nil { - return err - } - - secretsFile := filepath.Join(Config.RootDir, "secrets.yaml") - if err = validateFileExists(secretsFile); err != nil { - return err - } - - return writeToDestination(bundleBytes, secretsFile, 0o600) -} - // readChartYamlPreset reads Chart.yaml and determines the preset name from dependencies func readChartYamlPreset() (string, error) { chartYamlPath := filepath.Join(Config.RootDir, "Chart.yaml") @@ -692,37 +561,259 @@ func init() { // Don't mark preset as required - it's validated in PreRunE based on --encrypt/--decrypt flags } -func isValidPreset(preset string, availablePresets []string) bool { - return slices.Contains(availablePresets, preset) +// GenerateOptions holds options for project generation. +type GenerateOptions struct { + RootDir string + Preset string + ClusterName string + TalosVersion string + Version string // Chart version, e.g. "0.1.0" + Force bool + KubeconfigName string // base filename for kubeconfig in .gitignore (default: "kubeconfig") + ValuesOverrides map[string]interface{} // optional: merge into generated values.yaml + Endpoint string // optional: API server endpoint (e.g. "https://203.0.113.1:6443"). Defaults to placeholder. +} + +// GenerateProject creates a new talm project: secrets, talosconfig, preset files, .gitignore, and nodes directory. +// It does not handle encryption — callers should handle that separately if needed. +func GenerateProject(opts GenerateOptions) error { + var ( + versionContract *config.VersionContract + err error + ) + + if opts.TalosVersion != "" { + versionContract, err = config.ParseContractFromVersion(opts.TalosVersion) + if err != nil { + return fmt.Errorf("invalid talos-version: %w", err) + } + } + + secretsBundle, err := secrets.NewBundle(secrets.NewFixedClock(time.Now()), versionContract) + if err != nil { + return fmt.Errorf("failed to create secrets bundle: %w", err) + } + + // Apply the same DNS-label rules the interactive wizard enforces so both + // entry points produce consistent, valid clusters. + if err := wizard.ValidateClusterName(opts.ClusterName); err != nil { + return err + } + + availablePresets, err := generated.AvailablePresets() + if err != nil { + return fmt.Errorf("failed to get available presets: %w", err) + } + if !isValidPreset(opts.Preset, availablePresets) { + return fmt.Errorf("invalid preset: %s. Valid presets are: %v", opts.Preset, availablePresets) + } + + var genOptions []generate.Option + if versionContract != nil { + genOptions = append(genOptions, generate.WithVersionContract(versionContract)) + } + genOptions = append(genOptions, generate.WithSecretsBundle(secretsBundle)) + + // Write secrets.yaml + secretsFile := filepath.Join(opts.RootDir, "secrets.yaml") + if err := writeFileIfNotExists(secretsFile, opts.Force, func() ([]byte, error) { + return yaml.Marshal(secretsBundle) + }, 0o600); err != nil { + return err + } + + // Generate and write talosconfig + talosconfigFile := filepath.Join(opts.RootDir, "talosconfig") + if err := writeFileIfNotExists(talosconfigFile, opts.Force, func() ([]byte, error) { + endpoint := opts.Endpoint + if endpoint == "" { + endpoint = "https://192.168.0.1:6443" + } + configBundle, err := gen.GenerateConfigBundle(genOptions, opts.ClusterName, endpoint, "", []string{}, []string{}, []string{}) + if err != nil { + return nil, err + } + // Default Endpoints is the loopback-only 127.0.0.1 — replace with the + // actual host from the user-provided endpoint so talosconfig points + // at the cluster out of the box. + endpointHost := "127.0.0.1" + if u, err := url.Parse(endpoint); err == nil && u.Hostname() != "" { + endpointHost = u.Hostname() + } + configBundle.TalosConfig().Contexts[opts.ClusterName].Endpoints = []string{endpointHost} + return yaml.Marshal(configBundle.TalosConfig()) + }, 0o600); err != nil { + return err + } + + // Create nodes directory + nodesDir := filepath.Join(opts.RootDir, "nodes") + if err := os.MkdirAll(nodesDir, os.ModePerm); err != nil { + return fmt.Errorf("failed to create nodes directory: %w", err) + } + + // Write preset and library chart files + presetFiles, err := generated.PresetFiles() + if err != nil { + return fmt.Errorf("failed to get preset files: %w", err) + } + + version := opts.Version + if version == "" { + version = "0.1.0" + } + + for path, content := range presetFiles { + parts := strings.SplitN(path, "/", 2) + chartName := parts[0] + + if chartName == opts.Preset { + file := filepath.Join(opts.RootDir, filepath.Join(parts[1:]...)) + if parts[len(parts)-1] == "Chart.yaml" { + err = writeToFile(file, fmt.Appendf(nil, content, opts.ClusterName, version), opts.Force, 0o644) + } else { + err = writeToFile(file, []byte(content), opts.Force, 0o644) + } + if err != nil { + return err + } + } + + if chartName == "talm" { + file := filepath.Join(opts.RootDir, filepath.Join("charts", path)) + if parts[len(parts)-1] == "Chart.yaml" { + err = writeToFile(file, fmt.Appendf(nil, content, "talm", version), opts.Force, 0o644) + } else { + err = writeToFile(file, []byte(content), opts.Force, 0o644) + } + if err != nil { + return err + } + } + } + + // Apply values overrides if provided and values.yaml exists + valuesPath := filepath.Join(opts.RootDir, "values.yaml") + if len(opts.ValuesOverrides) > 0 { + if _, statErr := os.Stat(valuesPath); os.IsNotExist(statErr) { + // values.yaml doesn't exist — skip overrides silently + } else if err := mergeValuesOverrides(valuesPath, opts.ValuesOverrides); err != nil { + return fmt.Errorf("failed to apply values overrides: %w", err) + } + } + + // Write .gitignore + kubeconfigName := opts.KubeconfigName + if kubeconfigName == "" { + kubeconfigName = "kubeconfig" + } + return writeGitignoreForDir(opts.RootDir, kubeconfigName) } -func validateFileExists(file string) error { - if !initCmdFlags.force { - if _, err := os.Stat(file); err == nil { - return fmt.Errorf("file %q already exists, use --force to overwrite, and --update to update Talm library chart only", file) +// mergeValuesOverrides reads an existing values.yaml, applies top-level key overrides, and writes it back. +// This is a shallow merge: each override key REPLACES the entire value at that key, including lists. +// Callers must ensure overrides only contain top-level keys (not nested structures). +// Note: YAML comments and key ordering will not be preserved (marshal/unmarshal round-trip). +func mergeValuesOverrides(valuesPath string, overrides map[string]interface{}) error { + data, err := os.ReadFile(valuesPath) + if err != nil { + return err + } + + var values map[string]interface{} + if err := yaml.Unmarshal(data, &values); err != nil { + return err + } + if values == nil { + values = make(map[string]interface{}) + } + + for k, v := range overrides { + // Reject map-valued overrides entirely to prevent nested structure issues + if _, isMap := v.(map[string]interface{}); isMap { + return fmt.Errorf("map-valued override for key %q is not supported: use flat keys only", k) + } + // Reject overrides that would replace an existing map key + if existing, ok := values[k]; ok { + if _, existingIsMap := existing.(map[string]interface{}); existingIsMap { + return fmt.Errorf("cannot override map key %q: use flat keys only", k) + } + } + values[k] = v + } + + out, err := yaml.Marshal(values) + if err != nil { + return err + } + + return os.WriteFile(valuesPath, out, 0o644) +} + +// writeFileIfNotExists generates content lazily and writes it via writeToFile. +// When force is false and the file exists, it is silently skipped (not an error). +// This allows GenerateProject to be called on existing projects without failing. +func writeFileIfNotExists(path string, force bool, contentFn func() ([]byte, error), perm os.FileMode) error { + if !force { + if _, err := os.Stat(path); err == nil { + return nil // file exists, skip + } + } + + data, err := contentFn() + if err != nil { + return err + } + + // force=true here because we already checked above + return writeToFile(path, data, true, perm) +} + +// writeToFile writes data to a file, creating parent directories as needed. +// When force is false and the file exists, it is silently skipped. +func writeToFile(path string, data []byte, force bool, perm os.FileMode) error { + if !force { + if _, err := os.Stat(path); err == nil { + return nil // file exists, skip } } + if err := os.MkdirAll(filepath.Dir(path), os.ModePerm); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + + if err := os.WriteFile(path, data, perm); err != nil { + return fmt.Errorf("failed to write %s: %w", path, err) + } + + fmt.Fprintf(os.Stderr, "Created %s\n", path) return nil } -func writeGitignoreFile() error { - requiredEntries := []string{"secrets.yaml", "talosconfig", "talm.key"} +func isValidPreset(preset string, availablePresets []string) bool { + return slices.Contains(availablePresets, preset) +} - // Add kubeconfig to required entries (use path from config or default) +func writeGitignoreFile() error { kubeconfigPath := Config.GlobalOptions.Kubeconfig if kubeconfigPath == "" { kubeconfigPath = "kubeconfig" } - // Only add base name (not full path) to gitignore - kubeconfigBase := filepath.Base(kubeconfigPath) - requiredEntries = append(requiredEntries, kubeconfigBase) + return writeGitignoreForDir(Config.RootDir, filepath.Base(kubeconfigPath)) +} - gitignoreFile := filepath.Join(Config.RootDir, ".gitignore") +// writeGitignoreForDir creates or updates .gitignore with required entries. +// kubeconfigName is the base filename of the kubeconfig (e.g. "kubeconfig"). +func writeGitignoreForDir(rootDir string, kubeconfigName string) error { + requiredEntries := []string{"secrets.yaml", "talosconfig", "talm.key", kubeconfigName} + + gitignoreFile := filepath.Join(rootDir, ".gitignore") var existingStr string + fileExisted := false // If .gitignore exists, read it if _, err := os.Stat(gitignoreFile); err == nil { + fileExisted = true existingContent, err := os.ReadFile(gitignoreFile) if err != nil { return fmt.Errorf("failed to read existing .gitignore: %w", err) @@ -764,13 +855,15 @@ func writeGitignoreFile() error { if err := os.MkdirAll(parentDir, os.ModePerm); err != nil { return fmt.Errorf("failed to create output dir: %w", err) } - err := os.WriteFile(gitignoreFile, []byte(existingStr), 0o644) - if _, statErr := os.Stat(gitignoreFile); statErr == nil { + if err := os.WriteFile(gitignoreFile, []byte(existingStr), 0o644); err != nil { + return err + } + if fileExisted { fmt.Fprintf(os.Stderr, "Updated %s\n", gitignoreFile) } else { fmt.Fprintf(os.Stderr, "Created %s\n", gitignoreFile) } - return err + return nil } func fileExists(file string) bool { @@ -861,21 +954,3 @@ func handleTalosconfigEncryption(requireKeyForDecrypt bool) (bool, error) { return keyWasCreated, nil } -func writeToDestination(data []byte, destination string, permissions os.FileMode) error { - if err := validateFileExists(destination); err != nil { - return err - } - - parentDir := filepath.Dir(destination) - - // Create dir path, ignoring "already exists" messages - if err := os.MkdirAll(parentDir, os.ModePerm); err != nil { - return fmt.Errorf("failed to create output dir: %w", err) - } - - err := os.WriteFile(destination, data, permissions) - - fmt.Fprintf(os.Stderr, "Created %s\n", destination) - - return err -} diff --git a/pkg/commands/init_test.go b/pkg/commands/init_test.go new file mode 100644 index 0000000..d30012d --- /dev/null +++ b/pkg/commands/init_test.go @@ -0,0 +1,497 @@ +package commands + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/cozystack/talm/pkg/wizard" +) + +func TestGenerateProject_Generic(t *testing.T) { + rootDir := t.TempDir() + opts := GenerateOptions{ + RootDir: rootDir, + Preset: "generic", + ClusterName: "test-cluster", + Version: "0.1.0", + Force: false, + } + + if err := GenerateProject(opts); err != nil { + t.Fatalf("GenerateProject failed: %v", err) + } + + assertFileExists(t, rootDir, "secrets.yaml") + assertFileExists(t, rootDir, "talosconfig") + assertFileExists(t, rootDir, "Chart.yaml") + assertFileExists(t, rootDir, "values.yaml") + assertFileExists(t, rootDir, ".gitignore") + assertDirExists(t, rootDir, "nodes") + assertDirExists(t, rootDir, "templates") + assertDirExists(t, rootDir, "charts/talm") + assertFileExists(t, rootDir, "charts/talm/Chart.yaml") + assertFileExists(t, rootDir, "charts/talm/templates/_helpers.tpl") + + assertFileContains(t, rootDir, "Chart.yaml", "test-cluster") + assertFileContains(t, rootDir, "Chart.yaml", "0.1.0") + + gitignore := readFile(t, rootDir, ".gitignore") + for _, entry := range []string{"secrets.yaml", "talosconfig", "talm.key", "kubeconfig"} { + if !strings.Contains(gitignore, entry) { + t.Errorf(".gitignore missing entry %q", entry) + } + } +} + +func TestGenerateProject_Cozystack(t *testing.T) { + rootDir := t.TempDir() + opts := GenerateOptions{ + RootDir: rootDir, + Preset: "cozystack", + ClusterName: "cozy-cluster", + Version: "1.0.0", + Force: false, + } + + if err := GenerateProject(opts); err != nil { + t.Fatalf("GenerateProject failed: %v", err) + } + + assertFileExists(t, rootDir, "secrets.yaml") + assertFileExists(t, rootDir, "talosconfig") + assertFileExists(t, rootDir, "Chart.yaml") + assertFileExists(t, rootDir, "values.yaml") + assertFileExists(t, rootDir, "nodes") + + assertFileContains(t, rootDir, "Chart.yaml", "cozy-cluster") + assertFileContains(t, rootDir, "values.yaml", "floatingIP") +} + +func TestGenerateProject_InvalidPreset(t *testing.T) { + rootDir := t.TempDir() + opts := GenerateOptions{ + RootDir: rootDir, + Preset: "nonexistent", + ClusterName: "test", + Version: "0.1.0", + } + + err := GenerateProject(opts) + if err == nil { + t.Fatal("expected error for invalid preset, got nil") + } + if !strings.Contains(err.Error(), "invalid preset") { + t.Errorf("expected 'invalid preset' error, got: %v", err) + } +} + +func TestGenerateProject_SkipsExistingWithoutForce(t *testing.T) { + rootDir := t.TempDir() + + // Create existing secrets.yaml with known content + secretsFile := filepath.Join(rootDir, "secrets.yaml") + if err := os.WriteFile(secretsFile, []byte("existing-secret"), 0o600); err != nil { + t.Fatal(err) + } + + opts := GenerateOptions{ + RootDir: rootDir, + Preset: "generic", + ClusterName: "test", + Version: "0.1.0", + Force: false, + } + + // Should succeed, skipping existing files + if err := GenerateProject(opts); err != nil { + t.Fatalf("GenerateProject should skip existing files, got error: %v", err) + } + + // Verify existing file was NOT overwritten + content := readFile(t, rootDir, "secrets.yaml") + if content != "existing-secret" { + t.Error("secrets.yaml was overwritten despite Force=false") + } + + // But new files should still be created + assertFileExists(t, rootDir, "Chart.yaml") + assertFileExists(t, rootDir, "talosconfig") +} + +func TestGenerateProject_EmptyClusterName(t *testing.T) { + rootDir := t.TempDir() + opts := GenerateOptions{ + RootDir: rootDir, + Preset: "generic", + ClusterName: "", + Version: "0.1.0", + } + + err := GenerateProject(opts) + if err == nil { + t.Fatal("expected error for empty cluster name") + } + if !strings.Contains(err.Error(), "cluster name") { + t.Errorf("expected cluster name error, got: %v", err) + } +} + +// §6 — GenerateProject must reject names that would fail wizard validation +// (uppercase, dots, leading/trailing hyphens, etc.), so both entry points +// share the same rules. + +func TestGenerateProject_RejectsInvalidClusterName(t *testing.T) { + invalid := []string{ + "MyCluster", // uppercase + "my.cluster", // dot + "-cluster", // leading hyphen + "cluster-", // trailing hyphen + "my_cluster", // underscore + strings.Repeat("a", 64), // too long + } + for _, name := range invalid { + t.Run(name, func(t *testing.T) { + rootDir := t.TempDir() + opts := GenerateOptions{ + RootDir: rootDir, + Preset: "generic", + ClusterName: name, + Version: "0.1.0", + } + if err := GenerateProject(opts); err == nil { + t.Errorf("expected error for cluster name %q, got nil", name) + } + }) + } +} + +// §5 — endpoint specified via GenerateOptions must end up in the generated +// talosconfig instead of the hardcoded placeholder. + +func TestGenerateProject_UsesProvidedEndpoint(t *testing.T) { + rootDir := t.TempDir() + opts := GenerateOptions{ + RootDir: rootDir, + Preset: "generic", + ClusterName: "test-cluster", + Version: "0.1.0", + Endpoint: "https://203.0.113.10:6443", + } + + if err := GenerateProject(opts); err != nil { + t.Fatalf("GenerateProject failed: %v", err) + } + + data := readFile(t, rootDir, "talosconfig") + if !strings.Contains(data, "203.0.113.10") { + t.Errorf("talosconfig should reference provided endpoint host, got:\n%s", data) + } + if strings.Contains(data, "192.168.0.1") { + t.Errorf("talosconfig still contains hardcoded placeholder 192.168.0.1:\n%s", data) + } + if strings.Contains(data, "127.0.0.1") { + t.Errorf("talosconfig endpoints should use provided host, not 127.0.0.1:\n%s", data) + } +} + +func TestMergeValuesOverrides_RejectsMapValuedOverrideForNewKey(t *testing.T) { + tmpDir := t.TempDir() + valuesPath := filepath.Join(tmpDir, "values.yaml") + + content := "endpoint: \"https://old:6443\"\n" + if err := os.WriteFile(valuesPath, []byte(content), 0o644); err != nil { + t.Fatal(err) + } + + overrides := map[string]interface{}{ + "newNestedKey": map[string]interface{}{"nested": "value"}, + } + + err := mergeValuesOverrides(valuesPath, overrides) + if err == nil { + t.Fatal("expected error for map-valued override, got nil") + } +} + +func TestGenerateProject_ForceOverwrite(t *testing.T) { + rootDir := t.TempDir() + + // Create existing secrets.yaml + secretsFile := filepath.Join(rootDir, "secrets.yaml") + if err := os.WriteFile(secretsFile, []byte("existing"), 0o600); err != nil { + t.Fatal(err) + } + + opts := GenerateOptions{ + RootDir: rootDir, + Preset: "generic", + ClusterName: "test", + Version: "0.1.0", + Force: true, + } + + if err := GenerateProject(opts); err != nil { + t.Fatalf("GenerateProject with force failed: %v", err) + } + + content := readFile(t, rootDir, "secrets.yaml") + if content == "existing" { + t.Error("secrets.yaml was not overwritten with force=true") + } +} + +func TestGenerateProject_DefaultVersion(t *testing.T) { + rootDir := t.TempDir() + opts := GenerateOptions{ + RootDir: rootDir, + Preset: "generic", + ClusterName: "test", + Version: "", // should default to "0.1.0" + } + + if err := GenerateProject(opts); err != nil { + t.Fatalf("GenerateProject failed: %v", err) + } + + assertFileContains(t, rootDir, "Chart.yaml", "0.1.0") +} + +func TestGenerateProject_ValuesOverrides(t *testing.T) { + rootDir := t.TempDir() + opts := GenerateOptions{ + RootDir: rootDir, + Preset: "generic", + ClusterName: "test", + Version: "0.1.0", + ValuesOverrides: map[string]interface{}{ + "endpoint": "https://10.0.0.1:6443", + }, + } + + if err := GenerateProject(opts); err != nil { + t.Fatalf("GenerateProject failed: %v", err) + } + + content := readFile(t, rootDir, "values.yaml") + if !strings.Contains(content, "https://10.0.0.1:6443") { + t.Errorf("values.yaml should contain overridden endpoint, got:\n%s", content) + } + // Original default endpoint should be replaced + if strings.Contains(content, "192.168.100.10") { + t.Error("values.yaml still contains default endpoint after override") + } +} + +func TestGenerateProject_ValuesOverridesPreservesOtherFields(t *testing.T) { + rootDir := t.TempDir() + opts := GenerateOptions{ + RootDir: rootDir, + Preset: "generic", + ClusterName: "test", + Version: "0.1.0", + ValuesOverrides: map[string]interface{}{ + "endpoint": "https://custom:6443", + }, + } + + if err := GenerateProject(opts); err != nil { + t.Fatalf("GenerateProject failed: %v", err) + } + + // podSubnets should still be present from preset defaults + assertFileContains(t, rootDir, "values.yaml", "podSubnets") + assertFileContains(t, rootDir, "values.yaml", "serviceSubnets") +} + +func TestMergeValuesOverrides_RejectsNestedMaps(t *testing.T) { + tmpDir := t.TempDir() + valuesPath := filepath.Join(tmpDir, "values.yaml") + + // Write a values.yaml with a nested map + content := "network:\n podSubnets:\n - 10.244.0.0/16\n serviceSubnets:\n - 10.96.0.0/16\n" + if err := os.WriteFile(valuesPath, []byte(content), 0o644); err != nil { + t.Fatal(err) + } + + // Attempt to override the nested map — should be rejected + overrides := map[string]interface{}{ + "network": map[string]interface{}{ + "podSubnets": []string{"custom"}, + }, + } + + err := mergeValuesOverrides(valuesPath, overrides) + if err == nil { + t.Fatal("expected error for map override, got nil") + } + if !strings.Contains(err.Error(), "not supported") { + t.Errorf("expected 'not supported' error, got: %v", err) + } +} + +func TestMergeValuesOverrides_ListReplacement(t *testing.T) { + tmpDir := t.TempDir() + valuesPath := filepath.Join(tmpDir, "values.yaml") + + content := "podSubnets:\n- 10.244.0.0/16\n- 10.245.0.0/16\n" + if err := os.WriteFile(valuesPath, []byte(content), 0o644); err != nil { + t.Fatal(err) + } + + overrides := map[string]interface{}{ + "podSubnets": []string{"10.244.0.0/16"}, + } + + if err := mergeValuesOverrides(valuesPath, overrides); err != nil { + t.Fatal(err) + } + + data, _ := os.ReadFile(valuesPath) + // List should be replaced entirely (only 1 entry, not 2) + if strings.Contains(string(data), "10.245.0.0/16") { + t.Error("second subnet should have been replaced (shallow merge replaces entire list)") + } +} + +func TestGenerateProject_Idempotent(t *testing.T) { + rootDir := t.TempDir() + opts := GenerateOptions{ + RootDir: rootDir, + Preset: "generic", + ClusterName: "test", + Version: "0.1.0", + Force: false, + } + + if err := GenerateProject(opts); err != nil { + t.Fatalf("first GenerateProject failed: %v", err) + } + + secretsBefore := readFile(t, rootDir, "secrets.yaml") + + // Run again — should succeed and NOT overwrite existing files + if err := GenerateProject(opts); err != nil { + t.Fatalf("second GenerateProject should be idempotent, got: %v", err) + } + + secretsAfter := readFile(t, rootDir, "secrets.yaml") + if secretsBefore != secretsAfter { + t.Error("secrets.yaml was overwritten on idempotent re-run") + } +} + +func TestMergeValuesOverrides_RejectsScalarOverMap(t *testing.T) { + tmpDir := t.TempDir() + valuesPath := filepath.Join(tmpDir, "values.yaml") + + content := "network:\n podSubnets:\n - 10.244.0.0/16\n" + if err := os.WriteFile(valuesPath, []byte(content), 0o644); err != nil { + t.Fatal(err) + } + + overrides := map[string]interface{}{ + "network": "flat-value", + } + + err := mergeValuesOverrides(valuesPath, overrides) + if err == nil { + t.Fatal("expected error when scalar replaces map, got nil") + } +} + +func TestMergeValuesOverrides_FlatKeysWork(t *testing.T) { + tmpDir := t.TempDir() + valuesPath := filepath.Join(tmpDir, "values.yaml") + + content := "endpoint: \"https://old:6443\"\npodSubnets:\n- 10.244.0.0/16\n" + if err := os.WriteFile(valuesPath, []byte(content), 0o644); err != nil { + t.Fatal(err) + } + + overrides := map[string]interface{}{ + "endpoint": "https://new:6443", + } + + if err := mergeValuesOverrides(valuesPath, overrides); err != nil { + t.Fatal(err) + } + + data, _ := os.ReadFile(valuesPath) + if !strings.Contains(string(data), "https://new:6443") { + t.Error("endpoint not updated") + } + if !strings.Contains(string(data), "podSubnets") { + t.Error("podSubnets lost after merge") + } +} + +func TestBuildValuesOverrides_EmptyEndpoint(t *testing.T) { + result := wizard.WizardResult{Endpoint: ""} + overrides := buildValuesOverrides(result) + if _, ok := overrides["endpoint"]; ok { + t.Error("empty endpoint should not be included in overrides") + } +} + +func TestBuildValuesOverrides_PopulatesFields(t *testing.T) { + result := wizard.WizardResult{ + Endpoint: "https://10.0.0.1:6443", + PodSubnets: "10.244.0.0/16", + ServiceSubnets: "10.96.0.0/16", + AdvertisedSubnets: "192.168.1.0/24", + FloatingIP: "10.0.0.100", + } + overrides := buildValuesOverrides(result) + + if overrides["endpoint"] != "https://10.0.0.1:6443" { + t.Errorf("endpoint = %v", overrides["endpoint"]) + } + if overrides["floatingIP"] != "10.0.0.100" { + t.Errorf("floatingIP = %v", overrides["floatingIP"]) + } + if _, ok := overrides["podSubnets"]; !ok { + t.Error("podSubnets missing") + } +} + +// Test helpers + +func assertFileExists(t *testing.T, rootDir, relPath string) { + t.Helper() + path := filepath.Join(rootDir, relPath) + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Errorf("expected file %s to exist", relPath) + } +} + +func assertDirExists(t *testing.T, rootDir, relPath string) { + t.Helper() + path := filepath.Join(rootDir, relPath) + info, err := os.Stat(path) + if os.IsNotExist(err) { + t.Errorf("expected directory %s to exist", relPath) + return + } + if !info.IsDir() { + t.Errorf("expected %s to be a directory", relPath) + } +} + +func assertFileContains(t *testing.T, rootDir, relPath, substring string) { + t.Helper() + content := readFile(t, rootDir, relPath) + if !strings.Contains(content, substring) { + t.Errorf("file %s does not contain %q", relPath, substring) + } +} + +func readFile(t *testing.T, rootDir, relPath string) string { + t.Helper() + data, err := os.ReadFile(filepath.Join(rootDir, relPath)) + if err != nil { + t.Fatalf("failed to read %s: %v", relPath, err) + } + return string(data) +} diff --git a/pkg/commands/interactive_init.go b/pkg/commands/interactive_init.go new file mode 100644 index 0000000..b4cf5e9 --- /dev/null +++ b/pkg/commands/interactive_init.go @@ -0,0 +1,183 @@ +// Copyright Cozystack Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package commands + +import ( + "fmt" + "os" + "path/filepath" + + tea "github.com/charmbracelet/bubbletea" + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" + + "github.com/cozystack/talm/pkg/generated" + "github.com/cozystack/talm/pkg/wizard" + "github.com/cozystack/talm/pkg/wizard/scan" + "github.com/cozystack/talm/pkg/wizard/tui" +) + +// interactiveCmd starts terminal TUI for interactive configuration. +// Registered as a root-level command (not under init) to avoid flag conflicts +// with the existing init command's --encrypt/--decrypt/--update flag validation. +var interactiveCmd = &cobra.Command{ + Use: "interactive", + Short: "Start interactive TUI wizard for cluster initialization", + Long: `Start a terminal-based UI (TUI) wizard that guides through cluster initialization.`, + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + presets, err := generated.AvailablePresets() + if err != nil { + return fmt.Errorf("failed to get available presets: %w", err) + } + + scanner := scan.New() + existing, isExisting := detectExistingProject(Config.RootDir) + + var projectGenerated bool + generateFn := func(result wizard.WizardResult) error { + overrides := buildValuesOverrides(result) + + // Skip full project scaffolding when the project is already + // initialized — only (re)write node stubs and values overrides. + if !isExisting { + if err := GenerateProject(GenerateOptions{ + RootDir: Config.RootDir, + Preset: result.Preset, + ClusterName: result.ClusterName, + TalosVersion: Config.TemplateOptions.TalosVersion, + Force: false, + Version: Config.InitOptions.Version, + ValuesOverrides: overrides, + Endpoint: result.Endpoint, + }); err != nil { + return err + } + } else { + valuesPath := filepath.Join(Config.RootDir, "values.yaml") + if err := mergeValuesOverrides(valuesPath, overrides); err != nil { + return err + } + } + + if err := wizard.WriteNodeFiles(Config.RootDir, result.Nodes); err != nil { + return err + } + + projectGenerated = true + return nil + } + + var model tui.Model + if isExisting { + model = tui.NewForExistingProject(scanner, existing, generateFn) + } else { + model = tui.New(scanner, presets, generateFn) + } + p := tea.NewProgram(model, tea.WithAltScreen()) + + finalModel, err := p.Run() + if err != nil { + return fmt.Errorf("wizard failed: %w", err) + } + + if m, ok := finalModel.(tui.Model); ok && m.Err() != nil { + return m.Err() + } + + // p.Run() has returned — alternate screen is torn down and the main + // terminal buffer is restored. Emit the encryption warning here so + // users actually see it. + if projectGenerated { + fmt.Fprintln(os.Stderr, "\nNote: Secrets are not encrypted. Run 'talm init --encrypt' to encrypt sensitive files.") + } + + return nil + }, +} + +// detectExistingProject returns the pre-populated wizard result (preset + +// cluster name) when rootDir already looks like an initialized talm project. +// Allows the wizard to skip steps the user has already answered. +func detectExistingProject(rootDir string) (wizard.WizardResult, bool) { + secretsExist := fileExists(filepath.Join(rootDir, "secrets.yaml")) || + fileExists(filepath.Join(rootDir, "secrets.yaml.encrypted")) + chartYaml := filepath.Join(rootDir, "Chart.yaml") + if !secretsExist || !fileExists(chartYaml) { + return wizard.WizardResult{}, false + } + + data, err := os.ReadFile(chartYaml) + if err != nil { + return wizard.WizardResult{}, false + } + var parsed struct { + Name string `yaml:"name"` + Dependencies []struct { + Name string `yaml:"name"` + } `yaml:"dependencies"` + } + if err := yaml.Unmarshal(data, &parsed); err != nil { + return wizard.WizardResult{}, false + } + + var preset string + for _, dep := range parsed.Dependencies { + if dep.Name != "talm" { + preset = dep.Name + break + } + } + if parsed.Name == "" || preset == "" { + return wizard.WizardResult{}, false + } + return wizard.WizardResult{Preset: preset, ClusterName: parsed.Name}, true +} + +// buildValuesOverrides creates a map of values.yaml overrides from wizard results. +func buildValuesOverrides(result wizard.WizardResult) map[string]interface{} { + overrides := map[string]interface{}{} + + if result.Endpoint != "" { + overrides["endpoint"] = result.Endpoint + } + + if result.PodSubnets != "" { + overrides["podSubnets"] = []string{result.PodSubnets} + } + if result.ServiceSubnets != "" { + overrides["serviceSubnets"] = []string{result.ServiceSubnets} + } + if result.AdvertisedSubnets != "" { + overrides["advertisedSubnets"] = []string{result.AdvertisedSubnets} + } + + // Cozystack-specific + if result.ClusterDomain != "" { + overrides["clusterDomain"] = result.ClusterDomain + } + if result.FloatingIP != "" { + overrides["floatingIP"] = result.FloatingIP + } + if result.Image != "" { + overrides["image"] = result.Image + } + + return overrides +} + +func init() { + addCommand(interactiveCmd) +} diff --git a/pkg/wizard/interfaces.go b/pkg/wizard/interfaces.go new file mode 100644 index 0000000..ae3ee78 --- /dev/null +++ b/pkg/wizard/interfaces.go @@ -0,0 +1,21 @@ +package wizard + +import "context" + +// ScanResult holds the result of a network scan. +type ScanResult struct { + Nodes []NodeInfo + Warnings []string +} + +// Scanner discovers Talos nodes on the network and collects hardware information. +type Scanner interface { + // ScanNetwork discovers Talos nodes in the given CIDR range. + ScanNetwork(ctx context.Context, cidr string) ([]NodeInfo, error) + + // ScanNetworkFull is like ScanNetwork but also returns warnings. + ScanNetworkFull(ctx context.Context, cidr string) (ScanResult, error) + + // GetNodeInfo connects to a single node and retrieves its hardware details. + GetNodeInfo(ctx context.Context, ip string) (NodeInfo, error) +} diff --git a/pkg/wizard/nodefile.go b/pkg/wizard/nodefile.go new file mode 100644 index 0000000..64baaf2 --- /dev/null +++ b/pkg/wizard/nodefile.go @@ -0,0 +1,97 @@ +package wizard + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/cozystack/talm/pkg/modeline" +) + +// WriteNodeFiles creates stub node config files in the nodes/ directory. +// Each file contains a modeline pointing to the node's IP and the appropriate template. +// Existing files are not overwritten. +func WriteNodeFiles(rootDir string, nodes []NodeConfig) error { + nodesDir := filepath.Join(rootDir, "nodes") + if err := os.MkdirAll(nodesDir, 0o755); err != nil { + return fmt.Errorf("failed to create nodes directory: %w", err) + } + + // Validate + dedup by the *sanitized* filename so inputs like "cp-1" and + // "../cp-1" can't collide silently. + seen := make(map[string]bool, len(nodes)) + for _, node := range nodes { + safeName := filepath.Base(node.Hostname) + if safeName == "." || safeName == ".." || safeName == "" || strings.ContainsAny(safeName, "/\\") { + return fmt.Errorf("invalid hostname for file creation: %q", node.Hostname) + } + if err := ValidateHostname(safeName); err != nil { + return fmt.Errorf("invalid hostname for file creation: %w", err) + } + if seen[safeName] { + return fmt.Errorf("duplicate hostname after sanitization: %q", safeName) + } + seen[safeName] = true + } + + for _, node := range nodes { + safeName := filepath.Base(node.Hostname) + filePath := filepath.Join(nodesDir, safeName+".yaml") + + if _, err := os.Stat(filePath); err == nil { + fmt.Fprintf(os.Stderr, "Skipping %s (already exists)\n", filePath) + continue + } + + nodeIP := extractIP(node.Addresses) + managementIP := node.ManagementIP + if managementIP == "" { + managementIP = nodeIP + } + + templateFile, err := templateForRole(node.Role) + if err != nil { + return err + } + + ml, err := modeline.GenerateModeline( + []string{nodeIP}, + []string{managementIP}, + []string{templateFile}, + ) + if err != nil { + return fmt.Errorf("failed to generate modeline for %s: %w", node.Hostname, err) + } + + if err := os.WriteFile(filePath, []byte(ml+"\n"), 0o644); err != nil { + return fmt.Errorf("failed to write %s: %w", filePath, err) + } + + fmt.Fprintf(os.Stderr, "Created %s\n", filePath) + } + + return nil +} + +// extractIP returns the IP address without CIDR mask. +func extractIP(address string) string { + if idx := strings.IndexByte(address, '/'); idx >= 0 { + return address[:idx] + } + return address +} + +// templateForRole returns the template file path for the given node role. +// Unknown roles return an error rather than silently falling back to worker — +// that would mask typos like "master" as correctly-generated artifacts. +func templateForRole(role string) (string, error) { + switch role { + case "controlplane": + return "templates/controlplane.yaml", nil + case "worker": + return "templates/worker.yaml", nil + default: + return "", fmt.Errorf("unsupported node role: %q (expected %q or %q)", role, "controlplane", "worker") + } +} diff --git a/pkg/wizard/nodefile_test.go b/pkg/wizard/nodefile_test.go new file mode 100644 index 0000000..8cc5623 --- /dev/null +++ b/pkg/wizard/nodefile_test.go @@ -0,0 +1,294 @@ +package wizard + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestWriteNodeFiles_CreatesFiles(t *testing.T) { + rootDir := t.TempDir() + nodesDir := filepath.Join(rootDir, "nodes") + if err := os.MkdirAll(nodesDir, 0o755); err != nil { + t.Fatal(err) + } + + nodes := []NodeConfig{ + {Hostname: "cp-1", Role: "controlplane", Addresses: "10.0.0.1/24"}, + {Hostname: "worker-1", Role: "worker", Addresses: "10.0.0.2/24"}, + } + + if err := WriteNodeFiles(rootDir, nodes); err != nil { + t.Fatalf("WriteNodeFiles() error = %v", err) + } + + // Check files exist + for _, node := range nodes { + path := filepath.Join(nodesDir, node.Hostname+".yaml") + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Errorf("expected file %s to exist", path) + } + } +} + +func TestWriteNodeFiles_ModelineContent(t *testing.T) { + rootDir := t.TempDir() + if err := os.MkdirAll(filepath.Join(rootDir, "nodes"), 0o755); err != nil { + t.Fatal(err) + } + + nodes := []NodeConfig{ + {Hostname: "cp-1", Role: "controlplane", Addresses: "10.0.0.1/24"}, + } + + if err := WriteNodeFiles(rootDir, nodes); err != nil { + t.Fatal(err) + } + + data, err := os.ReadFile(filepath.Join(rootDir, "nodes", "cp-1.yaml")) + if err != nil { + t.Fatal(err) + } + content := string(data) + + if !strings.HasPrefix(content, "# talm:") { + t.Errorf("file should start with modeline, got: %s", content[:min(len(content), 50)]) + } + if !strings.Contains(content, `"10.0.0.1"`) { + t.Error("modeline should contain node IP") + } + if !strings.Contains(content, "controlplane.yaml") { + t.Error("modeline should reference controlplane template") + } +} + +func TestWriteNodeFiles_WorkerTemplate(t *testing.T) { + rootDir := t.TempDir() + if err := os.MkdirAll(filepath.Join(rootDir, "nodes"), 0o755); err != nil { + t.Fatal(err) + } + + nodes := []NodeConfig{ + {Hostname: "w-1", Role: "worker", Addresses: "10.0.0.5/24"}, + } + + if err := WriteNodeFiles(rootDir, nodes); err != nil { + t.Fatal(err) + } + + data, err := os.ReadFile(filepath.Join(rootDir, "nodes", "w-1.yaml")) + if err != nil { + t.Fatal(err) + } + content := string(data) + + if !strings.Contains(content, "worker.yaml") { + t.Error("modeline should reference worker template") + } +} + +func TestWriteNodeFiles_DoesNotOverwrite(t *testing.T) { + rootDir := t.TempDir() + nodesDir := filepath.Join(rootDir, "nodes") + if err := os.MkdirAll(nodesDir, 0o755); err != nil { + t.Fatal(err) + } + + existing := filepath.Join(nodesDir, "cp-1.yaml") + if err := os.WriteFile(existing, []byte("existing content"), 0o644); err != nil { + t.Fatal(err) + } + + nodes := []NodeConfig{ + {Hostname: "cp-1", Role: "controlplane", Addresses: "10.0.0.1/24"}, + } + + if err := WriteNodeFiles(rootDir, nodes); err != nil { + t.Fatal(err) + } + + data, err := os.ReadFile(existing) + if err != nil { + t.Fatal(err) + } + if string(data) != "existing content" { + t.Error("existing file was overwritten") + } +} + +func TestWriteNodeFiles_ExtractsIPFromCIDR(t *testing.T) { + rootDir := t.TempDir() + if err := os.MkdirAll(filepath.Join(rootDir, "nodes"), 0o755); err != nil { + t.Fatal(err) + } + + nodes := []NodeConfig{ + {Hostname: "node-1", Role: "worker", Addresses: "192.168.1.100/24"}, + } + + if err := WriteNodeFiles(rootDir, nodes); err != nil { + t.Fatal(err) + } + + data, err := os.ReadFile(filepath.Join(rootDir, "nodes", "node-1.yaml")) + if err != nil { + t.Fatal(err) + } + content := string(data) + + // Should contain bare IP (without /24) in modeline + if !strings.Contains(content, `"192.168.1.100"`) { + t.Errorf("modeline should contain bare IP without mask, got: %s", content) + } + if strings.Contains(content, `/24`) { + t.Error("modeline should not contain CIDR mask") + } +} + +func TestWriteNodeFiles_CreatesNodesDir(t *testing.T) { + rootDir := t.TempDir() + // Don't create nodes/ dir - WriteNodeFiles should create it + + nodes := []NodeConfig{ + {Hostname: "n1", Role: "worker", Addresses: "10.0.0.1/24"}, + } + + if err := WriteNodeFiles(rootDir, nodes); err != nil { + t.Fatal(err) + } + + path := filepath.Join(rootDir, "nodes", "n1.yaml") + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Error("file should be created even when nodes/ dir doesn't exist") + } +} + +func TestWriteNodeFiles_PathTraversal(t *testing.T) { + rootDir := t.TempDir() + + nodes := []NodeConfig{ + {Hostname: "../escape", Role: "worker", Addresses: "10.0.0.1/24"}, + } + + if err := WriteNodeFiles(rootDir, nodes); err != nil { + t.Fatal(err) + } + + // Should create nodes/escape.yaml (base name only), NOT ../escape.yaml + escapedPath := filepath.Join(rootDir, "escape.yaml") + if _, err := os.Stat(escapedPath); err == nil { + t.Error("path traversal: file created outside nodes/ directory") + } + + safePath := filepath.Join(rootDir, "nodes", "escape.yaml") + if _, err := os.Stat(safePath); os.IsNotExist(err) { + t.Error("expected file at nodes/escape.yaml (sanitized)") + } +} + +func TestWriteNodeFiles_DuplicateHostnames(t *testing.T) { + rootDir := t.TempDir() + + nodes := []NodeConfig{ + {Hostname: "node-1", Role: "controlplane", Addresses: "10.0.0.1/24"}, + {Hostname: "node-1", Role: "worker", Addresses: "10.0.0.2/24"}, + } + + err := WriteNodeFiles(rootDir, nodes) + if err == nil { + t.Error("expected error for duplicate hostnames") + } +} + +func TestWriteNodeFiles_SlashHostname(t *testing.T) { + rootDir := t.TempDir() + + nodes := []NodeConfig{ + {Hostname: "/", Role: "worker", Addresses: "10.0.0.1/24"}, + } + + err := WriteNodeFiles(rootDir, nodes) + if err == nil { + t.Error("expected error for '/' hostname") + } +} + +func TestWriteNodeFiles_InvalidHostname(t *testing.T) { + rootDir := t.TempDir() + + nodes := []NodeConfig{ + {Hostname: "..", Role: "worker", Addresses: "10.0.0.1/24"}, + } + + err := WriteNodeFiles(rootDir, nodes) + if err == nil { + t.Error("expected error for '..' hostname") + } +} + +// §8 — two hostnames that sanitize to the same safe name must be rejected + +func TestWriteNodeFiles_NormalizedCollision(t *testing.T) { + rootDir := t.TempDir() + + nodes := []NodeConfig{ + {Hostname: "cp-1", Role: "controlplane", Addresses: "10.0.0.1/24"}, + {Hostname: "../cp-1", Role: "worker", Addresses: "10.0.0.2/24"}, + } + + err := WriteNodeFiles(rootDir, nodes) + if err == nil { + t.Error("expected error for hostnames that collide after sanitization") + } +} + +// §8 — unknown role must return an error, not silently fall back to worker + +func TestWriteNodeFiles_UnknownRole(t *testing.T) { + rootDir := t.TempDir() + + nodes := []NodeConfig{ + {Hostname: "master-1", Role: "master", Addresses: "10.0.0.1/24"}, + } + + err := WriteNodeFiles(rootDir, nodes) + if err == nil { + t.Error("expected error for unknown role 'master', got nil (silent worker fallback)") + } +} + +// §4 — when ManagementIP differs from node IP, modeline must carry both +// (nodes = node IP extracted from Addresses; endpoints = ManagementIP) + +func TestWriteNodeFiles_ManagementIPDistinctFromNodeIP(t *testing.T) { + rootDir := t.TempDir() + + nodes := []NodeConfig{ + { + Hostname: "cp-1", + Role: "controlplane", + Addresses: "10.0.0.1/24", + ManagementIP: "203.0.113.5", + }, + } + + if err := WriteNodeFiles(rootDir, nodes); err != nil { + t.Fatal(err) + } + + data, err := os.ReadFile(filepath.Join(rootDir, "nodes", "cp-1.yaml")) + if err != nil { + t.Fatal(err) + } + content := string(data) + + // nodes field must reference the internal address + if !strings.Contains(content, `"10.0.0.1"`) { + t.Errorf("modeline should contain node IP 10.0.0.1, got:\n%s", content) + } + // endpoints field must reference the management IP + if !strings.Contains(content, `"203.0.113.5"`) { + t.Errorf("modeline should contain management IP 203.0.113.5, got:\n%s", content) + } +} diff --git a/pkg/wizard/scan/extract.go b/pkg/wizard/scan/extract.go new file mode 100644 index 0000000..0d6f853 --- /dev/null +++ b/pkg/wizard/scan/extract.go @@ -0,0 +1,90 @@ +package scan + +import ( + "fmt" + + "github.com/cozystack/talm/pkg/wizard" + machineapi "github.com/siderolabs/talos/pkg/machinery/api/machine" + storageapi "github.com/siderolabs/talos/pkg/machinery/api/storage" +) + +// hostnameFromVersion extracts the hostname from a Version gRPC response. +func hostnameFromVersion(resp *machineapi.VersionResponse) string { + if resp == nil || len(resp.Messages) == 0 { + return "" + } + msg := resp.Messages[0] + if msg.Metadata == nil { + return "" + } + return msg.Metadata.Hostname +} + +// disksFromResponse extracts disk information from a Disks gRPC response. +func disksFromResponse(resp *storageapi.DisksResponse) []wizard.Disk { + if resp == nil || len(resp.Messages) == 0 { + return nil + } + + var disks []wizard.Disk + for _, d := range resp.Messages[0].Disks { + disks = append(disks, wizard.Disk{ + DevPath: fmt.Sprintf("/dev/%s", d.DeviceName), + Model: d.Model, + SizeBytes: d.Size, + }) + } + return disks +} + +// memoryFromResponse extracts total memory in bytes from a Memory gRPC response. +// Memtotal is in kB. +func memoryFromResponse(resp *machineapi.MemoryResponse) uint64 { + if resp == nil || len(resp.Messages) == 0 { + return 0 + } + msg := resp.Messages[0] + if msg.Meminfo == nil { + return 0 + } + return msg.Meminfo.Memtotal * 1024 +} + +// linkFromSpec builds a NetInterface from the spec map of a network.LinkStatus +// resource. Returns nil for non-physical links (bonds, vlans, links without +// a PCI/USB bus path). Pure helper — no gRPC. +func linkFromSpec(name string, spec map[string]interface{}) *wizard.NetInterface { + busPath, _ := spec["busPath"].(string) + kind, _ := spec["kind"].(string) + if busPath == "" || kind != "" { + return nil + } + mac, _ := spec["hardwareAddr"].(string) + return &wizard.NetInterface{ + Name: name, + MAC: mac, + } +} + +// addressFromSpec extracts the CIDR address and its link name from the spec of +// a network.AddressStatus resource. Returns empty strings for non-static or +// malformed addresses. +func addressFromSpec(spec map[string]interface{}) (linkName, cidr string) { + linkName, _ = spec["linkName"].(string) + cidr, _ = spec["address"].(string) + return linkName, cidr +} + +// defaultGatewayFromSpec extracts the next-hop gateway IP from the spec of a +// network.RouteStatus resource when it describes a default route. Returns an +// empty string otherwise. +func defaultGatewayFromSpec(spec map[string]interface{}) string { + dest, _ := spec["destination"].(string) + // Default route: destination empty or "0.0.0.0/0" / "::/0". + if dest != "" && dest != "0.0.0.0/0" && dest != "::/0" { + return "" + } + gw, _ := spec["gateway"].(string) + return gw +} + diff --git a/pkg/wizard/scan/extract_test.go b/pkg/wizard/scan/extract_test.go new file mode 100644 index 0000000..d7011b7 --- /dev/null +++ b/pkg/wizard/scan/extract_test.go @@ -0,0 +1,189 @@ +package scan + +import ( + "testing" + + "github.com/siderolabs/talos/pkg/machinery/api/common" + machineapi "github.com/siderolabs/talos/pkg/machinery/api/machine" + storageapi "github.com/siderolabs/talos/pkg/machinery/api/storage" +) + +func TestHostnameFromVersion(t *testing.T) { + tests := []struct { + name string + resp *machineapi.VersionResponse + expected string + }{ + { + name: "with hostname", + resp: &machineapi.VersionResponse{ + Messages: []*machineapi.Version{ + {Metadata: &common.Metadata{Hostname: "talos-cp-1"}}, + }, + }, + expected: "talos-cp-1", + }, + { + name: "nil response", + resp: nil, + expected: "", + }, + { + name: "empty messages", + resp: &machineapi.VersionResponse{ + Messages: []*machineapi.Version{}, + }, + expected: "", + }, + { + name: "nil metadata", + resp: &machineapi.VersionResponse{ + Messages: []*machineapi.Version{ + {Metadata: nil}, + }, + }, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := hostnameFromVersion(tt.resp) + if got != tt.expected { + t.Errorf("hostnameFromVersion() = %q, want %q", got, tt.expected) + } + }) + } +} + +func TestDisksFromResponse(t *testing.T) { + resp := &storageapi.DisksResponse{ + Messages: []*storageapi.Disks{ + { + Disks: []*storageapi.Disk{ + {DeviceName: "sda", Model: "Samsung SSD", Size: 500107862016}, + {DeviceName: "nvme0n1", Model: "Intel NVMe", Size: 1000204886016}, + }, + }, + }, + } + + disks := disksFromResponse(resp) + if len(disks) != 2 { + t.Fatalf("expected 2 disks, got %d", len(disks)) + } + + if disks[0].DevPath != "/dev/sda" { + t.Errorf("disk[0].DevPath = %q, want /dev/sda", disks[0].DevPath) + } + if disks[0].Model != "Samsung SSD" { + t.Errorf("disk[0].Model = %q, want Samsung SSD", disks[0].Model) + } + if disks[0].SizeBytes != 500107862016 { + t.Errorf("disk[0].SizeBytes = %d, want 500107862016", disks[0].SizeBytes) + } +} + +func TestDisksFromResponse_Nil(t *testing.T) { + disks := disksFromResponse(nil) + if len(disks) != 0 { + t.Errorf("expected 0 disks for nil response, got %d", len(disks)) + } +} + +func TestMemoryFromResponse(t *testing.T) { + resp := &machineapi.MemoryResponse{ + Messages: []*machineapi.Memory{ + { + Meminfo: &machineapi.MemInfo{ + Memtotal: 16384000, // in kB + }, + }, + }, + } + + bytes := memoryFromResponse(resp) + expected := uint64(16384000) * 1024 + if bytes != expected { + t.Errorf("memoryFromResponse() = %d, want %d", bytes, expected) + } +} + +func TestMemoryFromResponse_Nil(t *testing.T) { + if memoryFromResponse(nil) != 0 { + t.Error("expected 0 for nil response") + } +} + +// §9 — linkFromSpec must parse spec via direct type assertion (no YAML round-trip) + +func TestLinkFromSpec_PhysicalInterface(t *testing.T) { + spec := map[string]interface{}{ + "hardwareAddr": "aa:bb:cc:dd:ee:ff", + "busPath": "0000:00:1f.6", + // "kind" absent → physical + } + + iface := linkFromSpec("eth0", spec) + if iface == nil { + t.Fatal("expected NetInterface, got nil") + } + if iface.Name != "eth0" { + t.Errorf("Name = %q, want eth0", iface.Name) + } + if iface.MAC != "aa:bb:cc:dd:ee:ff" { + t.Errorf("MAC = %q", iface.MAC) + } +} + +func TestLinkFromSpec_SkipsVirtual(t *testing.T) { + // Bond: no busPath + if linkFromSpec("bond0", map[string]interface{}{"hardwareAddr": "xx"}) != nil { + t.Error("bond (no busPath) should be skipped") + } + // VLAN: has kind + if linkFromSpec("eth0.10", map[string]interface{}{"busPath": "x", "kind": "vlan"}) != nil { + t.Error("vlan (kind!=\"\") should be skipped") + } +} + +// §2 — addressFromSpec extracts linkName + CIDR for matching to interface + +func TestAddressFromSpec(t *testing.T) { + link, cidr := addressFromSpec(map[string]interface{}{ + "linkName": "eth0", + "address": "10.0.0.5/24", + }) + if link != "eth0" || cidr != "10.0.0.5/24" { + t.Errorf("got (%q, %q), want (eth0, 10.0.0.5/24)", link, cidr) + } +} + +// §2 — defaultGatewayFromSpec returns gateway only for default route + +func TestDefaultGatewayFromSpec_DefaultRoute(t *testing.T) { + gw := defaultGatewayFromSpec(map[string]interface{}{ + "destination": "0.0.0.0/0", + "gateway": "10.0.0.1", + }) + if gw != "10.0.0.1" { + t.Errorf("default route gateway = %q, want 10.0.0.1", gw) + } + + // Empty destination also means default route in COSI output + gw = defaultGatewayFromSpec(map[string]interface{}{"gateway": "10.0.0.2"}) + if gw != "10.0.0.2" { + t.Errorf("empty-destination gateway = %q, want 10.0.0.2", gw) + } +} + +func TestDefaultGatewayFromSpec_NonDefault(t *testing.T) { + gw := defaultGatewayFromSpec(map[string]interface{}{ + "destination": "192.168.1.0/24", + "gateway": "10.0.0.1", + }) + if gw != "" { + t.Errorf("non-default route should return empty, got %q", gw) + } +} + diff --git a/pkg/wizard/scan/scanner.go b/pkg/wizard/scan/scanner.go new file mode 100644 index 0000000..d74dd8f --- /dev/null +++ b/pkg/wizard/scan/scanner.go @@ -0,0 +1,313 @@ +package scan + +import ( + "bytes" + "context" + "crypto/tls" + "fmt" + "net" + "slices" + "strings" + "sync" + "time" + + "github.com/cosi-project/runtime/pkg/resource" + "github.com/cosi-project/runtime/pkg/resource/meta" + + "github.com/cozystack/talm/pkg/wizard" + "github.com/siderolabs/talos/cmd/talosctl/pkg/talos/helpers" + "github.com/siderolabs/talos/pkg/machinery/client" +) + +const ( + defaultTalosPort = 50000 + defaultTimeout = 30 * time.Second + maxConcurrentJobs = 10 +) + + +// TalosScanner discovers Talos nodes via TCP port scanning and collects +// hardware info via the Talos gRPC API. No external binaries required. +type TalosScanner struct { + Port int + Timeout time.Duration +} + +// New creates a scanner with default settings. +func New() *TalosScanner { + return &TalosScanner{ + Port: defaultTalosPort, + Timeout: defaultTimeout, + } +} + +// ScanNetwork discovers Talos nodes in the given CIDR range by TCP-scanning +// the Talos API port, then querying each discovered node for hardware details. +func (s *TalosScanner) ScanNetwork(ctx context.Context, cidr string) ([]wizard.NodeInfo, error) { + result, err := s.ScanNetworkFull(ctx, cidr) + if err != nil { + return nil, err + } + return result.Nodes, nil +} + +// ScanNetworkFull is like ScanNetwork but also returns warnings about +// nodes that were discovered by TCP but failed gRPC info collection. +func (s *TalosScanner) ScanNetworkFull(ctx context.Context, cidr string) (wizard.ScanResult, error) { + port := s.Port + if port == 0 { + port = defaultTalosPort + } + + ips, err := scanTCPPort(ctx, cidr, port, maxConcurrentJobs) + if err != nil { + return wizard.ScanResult{}, err + } + if len(ips) == 0 { + return wizard.ScanResult{}, nil + } + + return s.collectNodeInfo(ctx, ips) +} + +// GetNodeInfo connects to a single Talos node via gRPC and retrieves +// hostname, disks, memory, and network interface information. +func (s *TalosScanner) GetNodeInfo(ctx context.Context, ip string) (wizard.NodeInfo, error) { + node, _, err := s.getNodeInfoWithWarnings(ctx, ip) + return node, err +} + +// getNodeInfoWithWarnings is like GetNodeInfo but additionally returns non-fatal +// warnings (e.g. failed link listing) so the caller can surface them through +// the UI instead of the terminal while Bubble Tea owns the screen. +func (s *TalosScanner) getNodeInfoWithWarnings(ctx context.Context, ip string) (wizard.NodeInfo, []string, error) { + node := wizard.NodeInfo{IP: ip} + var warnings []string + + timeout := s.Timeout + if timeout == 0 { + timeout = defaultTimeout + } + infoCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + c, err := client.New(infoCtx, + client.WithEndpoints(ip), + client.WithTLSConfig(&tls.Config{InsecureSkipVerify: true}), //nolint:gosec + ) + if err != nil { + return node, warnings, err + } + defer func() { _ = c.Close() }() + + nodeCtx := client.WithNode(infoCtx, ip) + + if versionResp, err := c.Version(nodeCtx); err == nil { + node.Hostname = hostnameFromVersion(versionResp) + } + if disksResp, err := c.Disks(nodeCtx); err == nil { + node.Disks = disksFromResponse(disksResp) + } + if memResp, err := c.Memory(nodeCtx); err == nil { + node.RAMBytes = memoryFromResponse(memResp) + } + + ifaces, linkWarn := s.collectLinks(nodeCtx, c) + warnings = append(warnings, linkWarn...) + + addrs, addrWarn := s.collectAddresses(nodeCtx, c) + warnings = append(warnings, addrWarn...) + // Merge addresses into interfaces by link name. + for i := range ifaces { + if ips, ok := addrs[ifaces[i].Name]; ok { + ifaces[i].IPs = ips + } + } + node.Interfaces = ifaces + + gateway, routeWarn := s.collectDefaultGateway(nodeCtx, c) + warnings = append(warnings, routeWarn...) + node.DefaultGateway = gateway + + if node.Hostname == "" && len(node.Disks) == 0 && node.RAMBytes == 0 { + return node, warnings, fmt.Errorf("node %s: gRPC connected but returned no useful data", ip) + } + + return node, warnings, nil +} + +// collectLinks retrieves network link resources via the COSI API and +// returns physical interfaces only. Non-fatal errors are returned as warnings +// so the wizard can surface them through the TUI instead of the terminal. +func (s *TalosScanner) collectLinks(ctx context.Context, c *client.Client) ([]wizard.NetInterface, []string) { + var ( + interfaces []wizard.NetInterface + warnings []string + ) + + callbackRD := func(_ *meta.ResourceDefinition) error { return nil } + callbackResource := func(_ context.Context, _ string, r resource.Resource, callErr error) error { + if callErr != nil { + return nil + } + spec := specMapFromResource(r) + if spec == nil { + return nil + } + if iface := linkFromSpec(r.Metadata().ID(), spec); iface != nil { + interfaces = append(interfaces, *iface) + } + return nil + } + + if err := helpers.ForEachResource(ctx, c, callbackRD, callbackResource, "network", "links"); err != nil { + warnings = append(warnings, fmt.Sprintf("failed to list network links: %v", err)) + } + + return interfaces, warnings +} + +// collectAddresses returns a map of link name → [CIDR addresses] discovered +// via network.AddressStatus resources. +func (s *TalosScanner) collectAddresses(ctx context.Context, c *client.Client) (map[string][]string, []string) { + result := map[string][]string{} + var warnings []string + + callbackRD := func(_ *meta.ResourceDefinition) error { return nil } + callbackResource := func(_ context.Context, _ string, r resource.Resource, callErr error) error { + if callErr != nil { + return nil + } + spec := specMapFromResource(r) + if spec == nil { + return nil + } + link, cidr := addressFromSpec(spec) + if link == "" || cidr == "" { + return nil + } + result[link] = append(result[link], cidr) + return nil + } + + if err := helpers.ForEachResource(ctx, c, callbackRD, callbackResource, "network", "addressstatuses"); err != nil { + warnings = append(warnings, fmt.Sprintf("failed to list addresses: %v", err)) + } + + return result, warnings +} + +// collectDefaultGateway returns the next-hop IP of the first default route +// found on the node, or an empty string if there isn't one. +func (s *TalosScanner) collectDefaultGateway(ctx context.Context, c *client.Client) (string, []string) { + var ( + gateway string + warnings []string + ) + + callbackRD := func(_ *meta.ResourceDefinition) error { return nil } + callbackResource := func(_ context.Context, _ string, r resource.Resource, callErr error) error { + if callErr != nil || gateway != "" { + return nil + } + spec := specMapFromResource(r) + if spec == nil { + return nil + } + if gw := defaultGatewayFromSpec(spec); gw != "" { + gateway = gw + } + return nil + } + + if err := helpers.ForEachResource(ctx, c, callbackRD, callbackResource, "network", "routestatuses"); err != nil { + warnings = append(warnings, fmt.Sprintf("failed to list routes: %v", err)) + } + + return gateway, warnings +} + +// specMapFromResource extracts the spec map of a COSI resource using a direct +// type assertion on the value produced by resource.MarshalYAML. Avoids the +// YAML round-trip the original implementation used. +func specMapFromResource(r resource.Resource) map[string]interface{} { + yamlData, err := resource.MarshalYAML(r) + if err != nil { + return nil + } + resMap, ok := yamlData.(map[string]interface{}) + if !ok { + return nil + } + spec, ok := resMap["spec"].(map[string]interface{}) + if !ok { + return nil + } + return spec +} + +// collectNodeInfo queries multiple nodes concurrently with bounded parallelism. +// Returns discovered nodes and warnings for nodes that failed gRPC info collection. +func (s *TalosScanner) collectNodeInfo(ctx context.Context, ips []string) (wizard.ScanResult, error) { + var ( + mu sync.Mutex + nodes []wizard.NodeInfo + warnings []string + sem = make(chan struct{}, maxConcurrentJobs) + wg sync.WaitGroup + ) + + for _, ip := range ips { + wg.Add(1) + go func(ip string) { + defer wg.Done() + + select { + case sem <- struct{}{}: + defer func() { <-sem }() + case <-ctx.Done(): + return + } + + node, nodeWarn, err := s.getNodeInfoWithWarnings(ctx, ip) + if err != nil { + mu.Lock() + warnings = append(warnings, fmt.Sprintf("%s: %v", ip, err)) + for _, w := range nodeWarn { + warnings = append(warnings, fmt.Sprintf("%s: %s", ip, w)) + } + mu.Unlock() + return + } + if node.IP == "" { + node.IP = ip + } + + mu.Lock() + nodes = append(nodes, node) + for _, w := range nodeWarn { + warnings = append(warnings, fmt.Sprintf("%s: %s", ip, w)) + } + mu.Unlock() + }(ip) + } + + wg.Wait() + + if len(nodes) == 0 && len(ips) > 0 { + return wizard.ScanResult{Warnings: warnings}, + fmt.Errorf("found %d host(s) with open port %d but none responded as Talos nodes", len(ips), s.Port) + } + + // Sort by IP numerically for deterministic ordering + slices.SortFunc(nodes, func(a, b wizard.NodeInfo) int { + ipA := net.ParseIP(a.IP).To4() + ipB := net.ParseIP(b.IP).To4() + if ipA != nil && ipB != nil { + return bytes.Compare(ipA, ipB) + } + return strings.Compare(a.IP, b.IP) + }) + + return wizard.ScanResult{Nodes: nodes, Warnings: warnings}, nil +} diff --git a/pkg/wizard/scan/scanner_test.go b/pkg/wizard/scan/scanner_test.go new file mode 100644 index 0000000..601c0dd --- /dev/null +++ b/pkg/wizard/scan/scanner_test.go @@ -0,0 +1,21 @@ +package scan + +import ( + "testing" +) + +func TestNew(t *testing.T) { + s := New() + if s.Port != defaultTalosPort { + t.Errorf("Port = %d, want %d", s.Port, defaultTalosPort) + } + if s.Timeout != defaultTimeout { + t.Errorf("Timeout = %v, want %v", s.Timeout, defaultTimeout) + } +} + +// Note: ScanNetwork and GetNodeInfo require real Talos nodes or network +// access, so they are tested via integration tests only. +// Unit tests for the underlying components are in: +// - tcpscan_test.go (TCP port scanning, CIDR expansion) +// - extract_test.go (gRPC response parsing) diff --git a/pkg/wizard/scan/tcpscan.go b/pkg/wizard/scan/tcpscan.go new file mode 100644 index 0000000..d0dafe0 --- /dev/null +++ b/pkg/wizard/scan/tcpscan.go @@ -0,0 +1,128 @@ +package scan + +import ( + "context" + "encoding/binary" + "fmt" + "net" + "sync" + "time" +) + +const dialTimeout = 2 * time.Second + +// scanTCPPort scans all hosts in the given CIDR for an open TCP port. +// Returns a list of IPs that accepted the connection. +// +// Uses a fixed worker pool of maxWorkers goroutines reading from a jobs +// channel — goroutine count stays bounded regardless of the input range. +// A goroutine-per-host approach would spike to thousands for /16 inputs. +func scanTCPPort(ctx context.Context, cidr string, port int, maxWorkers int) ([]string, error) { + hosts, err := enumerateHosts(cidr) + if err != nil { + return nil, fmt.Errorf("failed to enumerate hosts: %w", err) + } + if maxWorkers < 1 { + return nil, fmt.Errorf("maxWorkers must be >= 1, got %d", maxWorkers) + } + + var ( + mu sync.Mutex + results []string + jobs = make(chan string) + wg sync.WaitGroup + ) + + worker := func() { + defer wg.Done() + dialer := net.Dialer{Timeout: dialTimeout} + for ip := range jobs { + addr := net.JoinHostPort(ip, fmt.Sprintf("%d", port)) + conn, err := dialer.DialContext(ctx, "tcp", addr) + if err != nil { + continue + } + _ = conn.Close() + + mu.Lock() + results = append(results, ip) + mu.Unlock() + } + } + + for range maxWorkers { + wg.Add(1) + go worker() + } + +feed: + for _, host := range hosts { + select { + case jobs <- host.String(): + case <-ctx.Done(): + break feed + } + } + close(jobs) + wg.Wait() + + if ctx.Err() != nil { + return nil, ctx.Err() + } + + return results, nil +} + +// enumerateHosts expands a CIDR notation to a list of usable host IPs. +// It skips the network and broadcast addresses for subnets larger than /31. +func enumerateHosts(cidr string) ([]net.IP, error) { + _, ipNet, err := net.ParseCIDR(cidr) + if err != nil { + return nil, err + } + + ones, bits := ipNet.Mask.Size() + if bits != 32 { + return nil, fmt.Errorf("only IPv4 CIDR is supported, got /%d bits", bits) + } + + // Reject unreasonably large scans (>/16 = 65534 hosts) + if ones < 16 { + return nil, fmt.Errorf("CIDR range /%d is too large (max /%d), would scan %d hosts", ones, 16, 1<<(32-ones)) + } + + // /32 — single host + if ones == 32 { + return []net.IP{ipNet.IP.To4()}, nil + } + + // /31 — point-to-point, both addresses are usable (RFC 3021) + if ones == 31 { + start := ipToUint32(ipNet.IP.To4()) + return []net.IP{uint32ToIP(start), uint32ToIP(start + 1)}, nil + } + + // For /30 and larger: enumerate usable hosts (skip network and broadcast addresses). + // totalHosts includes network + broadcast, so usable = totalHosts - 2. + // start = network + 1 (first usable), end = network + totalHosts - 2 (last usable). + totalHosts := uint32(1) << (32 - ones) + start := ipToUint32(ipNet.IP.To4()) + 1 + end := start + totalHosts - 3 // -1 (inclusive range) -1 (skip broadcast) -1 (start already +1) + + hosts := make([]net.IP, 0, end-start+1) + for i := start; i <= end; i++ { + hosts = append(hosts, uint32ToIP(i)) + } + + return hosts, nil +} + +func ipToUint32(ip net.IP) uint32 { + return binary.BigEndian.Uint32(ip) +} + +func uint32ToIP(n uint32) net.IP { + ip := make(net.IP, 4) + binary.BigEndian.PutUint32(ip, n) + return ip +} diff --git a/pkg/wizard/scan/tcpscan_test.go b/pkg/wizard/scan/tcpscan_test.go new file mode 100644 index 0000000..3d0593b --- /dev/null +++ b/pkg/wizard/scan/tcpscan_test.go @@ -0,0 +1,203 @@ +package scan + +import ( + "context" + "fmt" + "net" + "runtime" + "sync/atomic" + "testing" + "time" +) + +func TestEnumerateHosts(t *testing.T) { + tests := []struct { + name string + cidr string + expected int + wantErr bool + }{ + {"slash 30", "10.0.0.0/30", 2, false}, + {"slash 32", "10.0.0.1/32", 1, false}, + {"slash 31", "10.0.0.0/31", 2, false}, + {"slash 24", "192.168.1.0/24", 254, false}, + {"invalid cidr", "not-a-cidr", 0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hosts, err := enumerateHosts(tt.cidr) + if (err != nil) != tt.wantErr { + t.Errorf("enumerateHosts(%q) error = %v, wantErr %v", tt.cidr, err, tt.wantErr) + return + } + if len(hosts) != tt.expected { + t.Errorf("enumerateHosts(%q) returned %d hosts, want %d", tt.cidr, len(hosts), tt.expected) + } + }) + } +} + +func TestEnumerateHosts_SkipsNetworkAndBroadcast(t *testing.T) { + hosts, err := enumerateHosts("10.0.0.0/30") + if err != nil { + t.Fatal(err) + } + + for _, h := range hosts { + ip := h.String() + if ip == "10.0.0.0" { + t.Error("should not include network address 10.0.0.0") + } + if ip == "10.0.0.3" { + t.Error("should not include broadcast address 10.0.0.3") + } + } +} + +func TestScanTCPPort_FindsOpenPort(t *testing.T) { + // Start a real TCP listener + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer func() { _ = listener.Close() }() + + port := listener.Addr().(*net.TCPAddr).Port + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + ips, err := scanTCPPort(ctx, "127.0.0.1/32", port, 1) + if err != nil { + t.Fatalf("scanTCPPort() error = %v", err) + } + + if len(ips) != 1 || ips[0] != "127.0.0.1" { + t.Errorf("scanTCPPort() = %v, want [127.0.0.1]", ips) + } +} + +func TestScanTCPPort_NoOpenPort(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + closedPort, err := pickClosedPort(t) + if err != nil { + t.Fatal(err) + } + + ips, err := scanTCPPort(ctx, "127.0.0.1/32", closedPort, 1) + if err != nil { + t.Fatalf("scanTCPPort() error = %v", err) + } + + if len(ips) != 0 { + t.Errorf("scanTCPPort() = %v, want empty", ips) + } +} + +// §11 — pickClosedPort returns a port that is *confirmed* to refuse connections. +// Picks ephemeral ports, closes them, probes with net.Dial to make sure no one +// raced in. Retries on collision. +func pickClosedPort(t *testing.T) (int, error) { + t.Helper() + for i := 0; i < 10; i++ { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return 0, err + } + port := l.Addr().(*net.TCPAddr).Port + _ = l.Close() + + conn, err := net.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", port), 100*time.Millisecond) + if err != nil { + return port, nil // port refused — good + } + _ = conn.Close() + } + return 0, fmt.Errorf("could not find a closed ephemeral port after 10 tries") +} + +func TestScanTCPPort_MultipleHosts(t *testing.T) { + listener1, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer func() { _ = listener1.Close() }() + + port := listener1.Addr().(*net.TCPAddr).Port + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // /32 has only one host — verify concurrency doesn't break anything + ips, err := scanTCPPort(ctx, "127.0.0.1/32", port, 10) + if err != nil { + t.Fatal(err) + } + if len(ips) != 1 { + t.Errorf("expected 1 IP, got %d", len(ips)) + } +} + +func TestEnumerateHosts_RejectsLargeCIDR(t *testing.T) { + _, err := enumerateHosts("10.0.0.0/8") + if err == nil { + t.Fatal("expected error for /8 CIDR, got nil") + } +} + +func TestEnumerateHosts_AcceptsSlash16(t *testing.T) { + hosts, err := enumerateHosts("10.0.0.0/16") + if err != nil { + t.Fatalf("expected /16 to be accepted, got error: %v", err) + } + if len(hosts) != 65534 { + t.Errorf("expected 65534 hosts, got %d", len(hosts)) + } +} + +// §10 — goroutine count must stay bounded by maxWorkers + small overhead, +// regardless of host count. Current goroutine-per-host implementation will +// spike to 1022 (for /22) and fail this test. +func TestScanTCPPort_BoundedGoroutines(t *testing.T) { + baseGoroutines := runtime.NumGoroutine() + + // Use a sink listener so dials succeed/fail cleanly. + // We don't care about results — only about runtime goroutine count. + var peak atomic.Int64 + done := make(chan struct{}) + defer close(done) + + go func() { + ticker := time.NewTicker(1 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-done: + return + case <-ticker.C: + cur := int64(runtime.NumGoroutine()) + if cur > peak.Load() { + peak.Store(cur) + } + } + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) + defer cancel() + + // /22 = 1022 hosts. With goroutine-per-host, peak goroutines will spike + // far above maxWorkers. + maxWorkers := 10 + _, _ = scanTCPPort(ctx, "127.0.0.0/22", 59999, maxWorkers) + + // Allow: base + maxWorkers (dial workers) + small overhead (ticker, test runtime). + budget := int64(baseGoroutines) + int64(maxWorkers) + 10 + if peak.Load() > budget { + t.Errorf("goroutine peak %d exceeds budget %d (base=%d, maxWorkers=%d) — worker pool not bounded", + peak.Load(), budget, baseGoroutines, maxWorkers) + } +} diff --git a/pkg/wizard/tui/model.go b/pkg/wizard/tui/model.go new file mode 100644 index 0000000..15e8bb6 --- /dev/null +++ b/pkg/wizard/tui/model.go @@ -0,0 +1,755 @@ +package tui + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/charmbracelet/bubbles/spinner" + "github.com/charmbracelet/bubbles/textinput" + tea "github.com/charmbracelet/bubbletea" + + "github.com/cozystack/talm/pkg/wizard" +) + +// step represents a stage in the wizard flow. +type step int + +const ( + stepSelectPreset step = iota + stepClusterName + stepEndpoint + stepScanCIDR + stepScanning + stepManualNodeEntry + stepSelectNodes + stepConfigureNode + stepConfirm + stepGenerating + stepDone + stepError +) + +// Node configuration field indices. +const ( + fieldRole = 0 + fieldHostname = 1 + fieldDisk = 2 + fieldInterface = 3 + fieldAddress = 4 + fieldGateway = 5 + fieldDNS = 6 + fieldManagementIP = 7 // optional, for DNAT / split-horizon setups + nodeFieldCount = 8 +) + +// Message types for async operations. +type ( + scanResultMsg struct { + nodes []wizard.NodeInfo + warnings []string + } + scanErrorMsg struct{ err error } + generateDoneMsg struct{} + generateErrorMsg struct{ err error } +) + +// GenerateFunc is called when the wizard completes to generate the project. +type GenerateFunc func(result wizard.WizardResult) error + +// Model is the bubbletea model for the interactive wizard. +type Model struct { + step step + err error + + // Wizard data + result wizard.WizardResult + presets []string + + // Sub-models + nameInput textinput.Model + endpointInput textinput.Model + cidrInput textinput.Model + manualIPInput textinput.Model + spinner spinner.Model + + // Node selection state + discoveredNodes []wizard.NodeInfo + scanWarnings []string + selectedNodes []int // indices into discoveredNodes + cursor int // for list navigation + + // Manual node entry + manualNodes []wizard.NodeInfo + + // Node configuration state + configuredNodes []wizard.NodeConfig + nodeInputs [nodeFieldCount]textinput.Model + nodeInputFocus int + currentNodeIdx int + + // Dependencies + scanner wizard.Scanner + generateFn GenerateFunc + + // Context for cancelling long-running operations + cancelScan context.CancelFunc + + // Step before error occurred, for returning on Esc (nil = no previous step) + prevStep *step + + // Terminal dimensions + width, height int +} + +// New creates a new wizard model. +func New(scanner wizard.Scanner, presets []string, generateFn GenerateFunc) Model { + s := spinner.New() + s.Spinner = spinner.Dot + + name := textinput.New() + name.Placeholder = "my-cluster" + name.CharLimit = 63 + + endpoint := textinput.New() + endpoint.Placeholder = "https://192.168.0.1:6443" + + cidr := textinput.New() + cidr.Placeholder = "192.168.1.0/24" + + manualIP := textinput.New() + manualIP.Placeholder = "192.168.1.10" + + var nodeInputs [nodeFieldCount]textinput.Model + for i := range nodeInputs { + nodeInputs[i] = textinput.New() + } + nodeInputs[fieldRole].Placeholder = "controlplane" + nodeInputs[fieldHostname].Placeholder = "node-01" + nodeInputs[fieldDisk].Placeholder = "/dev/sda" + nodeInputs[fieldInterface].Placeholder = "eth0" + nodeInputs[fieldAddress].Placeholder = "192.168.1.10/24" + nodeInputs[fieldGateway].Placeholder = "192.168.1.1" + nodeInputs[fieldDNS].Placeholder = "8.8.8.8,1.1.1.1" + nodeInputs[fieldManagementIP].Placeholder = "(optional) reachable IP, default = node address" + + return Model{ + step: stepSelectPreset, + presets: presets, + scanner: scanner, + + nameInput: name, + endpointInput: endpoint, + cidrInput: cidr, + manualIPInput: manualIP, + spinner: s, + nodeInputs: nodeInputs, + generateFn: generateFn, + } +} + +// NewForExistingProject creates a wizard model for a project that is already +// initialized (secrets.yaml + Chart.yaml exist). Preset and cluster name are +// taken from the on-disk state rather than asked again, so the wizard can be +// used to just add or reconfigure nodes on top of an existing project. +func NewForExistingProject(scanner wizard.Scanner, existing wizard.WizardResult, generateFn GenerateFunc) Model { + m := New(scanner, []string{existing.Preset}, generateFn) + m.result.Preset = existing.Preset + m.result.ClusterName = existing.ClusterName + m.step = stepEndpoint + return m +} + +// Init implements tea.Model. +func (m Model) Init() tea.Cmd { + return nil +} + +// Err returns any error that occurred during the wizard. +func (m Model) Err() error { + return m.err +} + +// Result returns the wizard result after completion. +func (m Model) Result() wizard.WizardResult { + return m.result +} + +// Step returns the current step (for testing). +func (m Model) Step() step { + return m.step +} + +// Update implements tea.Model. +func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + m.width = msg.Width + m.height = msg.Height + return m, nil + + case tea.KeyMsg: + switch msg.String() { + case "ctrl+c": + if m.cancelScan != nil { + m.cancelScan() + } + return m, tea.Quit + case "esc": + return m.handleBack() + } + + case scanResultMsg: + if m.step != stepScanning { + return m, nil // stale result from cancelled scan + } + if m.cancelScan != nil { + m.cancelScan() + m.cancelScan = nil + } + // Start fresh: a rescan must not inherit selection/cursor/warnings + // from the previous discovery, otherwise stale indexes can survive + // into the configure flow and preselect the wrong hosts. + m.discoveredNodes = msg.nodes + m.scanWarnings = msg.warnings + m.selectedNodes = nil + m.cursor = 0 + if len(msg.nodes) == 0 { + m.err = fmt.Errorf("no Talos nodes found in the specified network") + prev := stepScanCIDR + m.prevStep = &prev + m.step = stepError + return m, nil + } + m.step = stepSelectNodes + return m, nil + + case scanErrorMsg: + if m.step != stepScanning { + return m, nil // stale error from cancelled scan + } + if m.cancelScan != nil { + m.cancelScan() + m.cancelScan = nil + } + // prevStep must point at the user-facing step that triggered the + // scan (stepScanCIDR), not at stepScanning — otherwise Esc from + // stepError would land on an inert spinner with no command running. + m.err = msg.err + prev := stepScanCIDR + m.prevStep = &prev + m.step = stepError + return m, nil + + case generateDoneMsg: + m.step = stepDone + return m, nil + + case generateErrorMsg: + // Same reasoning as scanErrorMsg: Esc from stepError must return to + // stepConfirm where the user can retry, not to the stepGenerating + // spinner (no back-path out of there). + m.err = msg.err + prev := stepConfirm + m.prevStep = &prev + m.step = stepError + return m, nil + + case spinner.TickMsg: + if m.step == stepScanning || m.step == stepGenerating { + var cmd tea.Cmd + m.spinner, cmd = m.spinner.Update(msg) + return m, cmd + } + return m, nil + } + + switch m.step { + case stepSelectPreset: + return m.updateSelectPreset(msg) + case stepClusterName: + return m.updateClusterName(msg) + case stepEndpoint: + return m.updateEndpoint(msg) + case stepScanCIDR: + return m.updateScanCIDR(msg) + case stepManualNodeEntry: + return m.updateManualNodeEntry(msg) + case stepSelectNodes: + return m.updateSelectNodes(msg) + case stepConfigureNode: + return m.updateConfigureNode(msg) + case stepConfirm: + return m.updateConfirm(msg) + case stepDone: + return m.updateDone(msg) + case stepError: + return m.updateError(msg) + } + + return m, nil +} + +func (m Model) handleBack() (tea.Model, tea.Cmd) { + switch m.step { + case stepClusterName: + m.step = stepSelectPreset + case stepEndpoint: + m.step = stepClusterName + case stepScanCIDR: + m.step = stepEndpoint + case stepManualNodeEntry: + m.step = stepScanCIDR + m.manualNodes = nil + case stepScanning: + if m.cancelScan != nil { + m.cancelScan() + m.cancelScan = nil + } + m.step = stepScanCIDR + case stepSelectNodes: + m.step = stepScanCIDR + case stepConfigureNode: + if m.currentNodeIdx > 0 { + // Keep the already-saved previous config: when the user returns + // here they're editing, not re-entering. Rehydrate inputs from + // the stored NodeConfig so edits (disk, iface, gw, DNS) survive. + m.currentNodeIdx-- + m.restoreNodeInputs(m.currentNodeIdx) + } else { + m.step = stepSelectNodes + } + case stepConfirm: + // Return to the last configured node for editing. Keep the saved + // entry — the user may just want to tweak one field, not retype + // everything. m.result.Nodes is cleared so confirm-page state is + // recomputed on re-entry. + m.result.Nodes = nil + m.step = stepConfigureNode + if m.currentNodeIdx >= len(m.selectedNodes) { + m.currentNodeIdx = len(m.selectedNodes) - 1 + } + m.restoreNodeInputs(m.currentNodeIdx) + case stepError: + if m.prevStep != nil { + m.step = *m.prevStep + } else { + m.step = stepSelectPreset + } + m.err = nil + m.prevStep = nil + } + return m, nil +} + +func (m Model) updateSelectPreset(msg tea.Msg) (tea.Model, tea.Cmd) { + if keyMsg, ok := msg.(tea.KeyMsg); ok { + switch keyMsg.String() { + case "up", "k": + if m.cursor > 0 { + m.cursor-- + } + case "down", "j": + if m.cursor < len(m.presets)-1 { + m.cursor++ + } + case "enter": + m.result.Preset = m.presets[m.cursor] + m.step = stepClusterName + m.cursor = 0 + return m, m.nameInput.Focus() + } + } + return m, nil +} + +func (m Model) updateClusterName(msg tea.Msg) (tea.Model, tea.Cmd) { + if keyMsg, ok := msg.(tea.KeyMsg); ok && keyMsg.String() == "enter" { + name := m.nameInput.Value() + if err := wizard.ValidateClusterName(name); err != nil { + m.err = err + return m, nil + } + m.result.ClusterName = name + m.err = nil + m.step = stepEndpoint + return m, m.endpointInput.Focus() + } + + var cmd tea.Cmd + m.nameInput, cmd = m.nameInput.Update(msg) + return m, cmd +} + +func (m Model) updateEndpoint(msg tea.Msg) (tea.Model, tea.Cmd) { + if keyMsg, ok := msg.(tea.KeyMsg); ok && keyMsg.String() == "enter" { + endpoint := m.endpointInput.Value() + if err := wizard.ValidateEndpoint(endpoint); err != nil { + m.err = err + return m, nil + } + m.result.Endpoint = endpoint + m.err = nil + m.step = stepScanCIDR + return m, m.cidrInput.Focus() + } + + var cmd tea.Cmd + m.endpointInput, cmd = m.endpointInput.Update(msg) + return m, cmd +} + +func (m Model) updateScanCIDR(msg tea.Msg) (tea.Model, tea.Cmd) { + if keyMsg, ok := msg.(tea.KeyMsg); ok { + switch keyMsg.String() { + case "enter": + cidr := m.cidrInput.Value() + if err := wizard.ValidateCIDR(cidr); err != nil { + m.err = err + return m, nil + } + m.err = nil + m.step = stepScanning + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + m.cancelScan = cancel + return m, tea.Batch( + m.spinner.Tick, + scanNetworkCmd(ctx, m.scanner, cidr), + ) + case "ctrl+s": + m.err = nil + m.step = stepManualNodeEntry + m.manualNodes = nil + return m, m.manualIPInput.Focus() + } + } + + var cmd tea.Cmd + m.cidrInput, cmd = m.cidrInput.Update(msg) + return m, cmd +} + +func (m Model) updateManualNodeEntry(msg tea.Msg) (tea.Model, tea.Cmd) { + if keyMsg, ok := msg.(tea.KeyMsg); ok { + switch keyMsg.String() { + case "enter": + ip := m.manualIPInput.Value() + if ip == "" { + return m, nil + } + if err := wizard.ValidateIP(ip); err != nil { + m.err = err + return m, nil + } + m.err = nil + m.manualNodes = append(m.manualNodes, wizard.NodeInfo{IP: ip}) + m.manualIPInput.SetValue("") + return m, nil + case "ctrl+d": + if len(m.manualNodes) == 0 { + m.err = fmt.Errorf("add at least one node") + return m, nil + } + m.err = nil + m.discoveredNodes = m.manualNodes + // Pre-select all manual nodes + m.selectedNodes = make([]int, len(m.manualNodes)) + for i := range m.manualNodes { + m.selectedNodes[i] = i + } + m.step = stepSelectNodes + return m, nil + } + } + + var cmd tea.Cmd + m.manualIPInput, cmd = m.manualIPInput.Update(msg) + return m, cmd +} + +func (m Model) updateSelectNodes(msg tea.Msg) (tea.Model, tea.Cmd) { + if keyMsg, ok := msg.(tea.KeyMsg); ok { + switch keyMsg.String() { + case "up", "k": + if m.cursor > 0 { + m.cursor-- + } + case "down", "j": + if m.cursor < len(m.discoveredNodes)-1 { + m.cursor++ + } + case " ": + m.toggleNodeSelection() + case "enter": + if len(m.selectedNodes) == 0 { + m.err = fmt.Errorf("select at least one node") + return m, nil + } + m.err = nil + m.currentNodeIdx = 0 + m.configuredNodes = nil + m.step = stepConfigureNode + m.prepareNodeInputs() + return m, m.nodeInputs[fieldRole].Focus() + } + } + return m, nil +} + +func (m *Model) toggleNodeSelection() { + for i, idx := range m.selectedNodes { + if idx == m.cursor { + m.selectedNodes = append(m.selectedNodes[:i], m.selectedNodes[i+1:]...) + return + } + } + m.selectedNodes = append(m.selectedNodes, m.cursor) +} + +func (m *Model) prepareNodeInputs() { + if m.currentNodeIdx >= len(m.selectedNodes) { + return + } + node := m.discoveredNodes[m.selectedNodes[m.currentNodeIdx]] + + // Default role: first node is controlplane, rest are workers + if m.currentNodeIdx == 0 { + m.nodeInputs[fieldRole].SetValue("controlplane") + } else { + m.nodeInputs[fieldRole].SetValue("worker") + } + + m.nodeInputs[fieldHostname].SetValue(node.Hostname) + if len(node.Disks) > 0 { + m.nodeInputs[fieldDisk].SetValue(node.Disks[0].DevPath) + } else { + m.nodeInputs[fieldDisk].SetValue("") + } + if len(node.Interfaces) > 0 { + m.nodeInputs[fieldInterface].SetValue(node.Interfaces[0].Name) + } else { + m.nodeInputs[fieldInterface].SetValue("") + } + if len(node.Interfaces) > 0 && len(node.Interfaces[0].IPs) > 0 { + m.nodeInputs[fieldAddress].SetValue(node.Interfaces[0].IPs[0]) + } else { + m.nodeInputs[fieldAddress].SetValue("") + } + m.nodeInputs[fieldGateway].SetValue(node.DefaultGateway) + // DNS starts empty — no preconceived default, user must choose. + m.nodeInputs[fieldDNS].SetValue("") + m.nodeInputs[fieldManagementIP].SetValue("") + m.nodeInputFocus = 0 +} + +// restoreNodeInputs rehydrates the per-node inputs from a saved NodeConfig — +// used when the user backs into a node they already configured. +func (m *Model) restoreNodeInputs(idx int) { + if idx < 0 || idx >= len(m.configuredNodes) { + m.prepareNodeInputs() + return + } + nc := m.configuredNodes[idx] + m.nodeInputs[fieldRole].SetValue(nc.Role) + m.nodeInputs[fieldHostname].SetValue(nc.Hostname) + m.nodeInputs[fieldDisk].SetValue(nc.DiskPath) + m.nodeInputs[fieldInterface].SetValue(nc.Interface) + m.nodeInputs[fieldAddress].SetValue(nc.Addresses) + m.nodeInputs[fieldGateway].SetValue(nc.Gateway) + m.nodeInputs[fieldDNS].SetValue(strings.Join(nc.DNS, ",")) + m.nodeInputs[fieldManagementIP].SetValue(nc.ManagementIP) + m.nodeInputFocus = 0 +} + +func (m Model) updateConfigureNode(msg tea.Msg) (tea.Model, tea.Cmd) { + if keyMsg, ok := msg.(tea.KeyMsg); ok { + // Role field is a toggle, not a text input — space/left/right flip + // between the only two valid values instead of letting the user + // type a free-form string and then fail validation. + if m.nodeInputFocus == fieldRole { + switch keyMsg.String() { + case " ", "left", "right", "h", "l": + if m.nodeInputs[fieldRole].Value() == "controlplane" { + m.nodeInputs[fieldRole].SetValue("worker") + } else { + m.nodeInputs[fieldRole].SetValue("controlplane") + } + return m, nil + } + } + switch keyMsg.String() { + case "tab": + m.nodeInputFocus = (m.nodeInputFocus + 1) % nodeFieldCount + return m, m.nodeInputs[m.nodeInputFocus].Focus() + case "shift+tab": + m.nodeInputFocus = (m.nodeInputFocus - 1 + nodeFieldCount) % nodeFieldCount + return m, m.nodeInputs[m.nodeInputFocus].Focus() + case "enter": + nc, err := m.validateAndBuildNodeConfig() + if err != nil { + m.err = err + return m, nil + } + m.err = nil + // Update the existing slot when editing, append when adding a + // fresh node. Prevents duplicates after back-navigation. + if m.currentNodeIdx < len(m.configuredNodes) { + m.configuredNodes[m.currentNodeIdx] = nc + } else { + m.configuredNodes = append(m.configuredNodes, nc) + } + m.currentNodeIdx++ + + if m.currentNodeIdx >= len(m.selectedNodes) { + m.result.Nodes = m.configuredNodes + m.step = stepConfirm + return m, nil + } + // Rehydrate from saved config if this node was already visited, + // otherwise start from discovery defaults. + if m.currentNodeIdx < len(m.configuredNodes) { + m.restoreNodeInputs(m.currentNodeIdx) + } else { + m.prepareNodeInputs() + } + return m, m.nodeInputs[fieldRole].Focus() + } + } + + var cmd tea.Cmd + m.nodeInputs[m.nodeInputFocus], cmd = m.nodeInputs[m.nodeInputFocus].Update(msg) + return m, cmd +} + +func (m Model) validateAndBuildNodeConfig() (wizard.NodeConfig, error) { + role := m.nodeInputs[fieldRole].Value() + if err := wizard.ValidateNodeRole(role); err != nil { + return wizard.NodeConfig{}, err + } + + hostname := m.nodeInputs[fieldHostname].Value() + if err := wizard.ValidateHostname(hostname); err != nil { + return wizard.NodeConfig{}, err + } + + diskPath := m.nodeInputs[fieldDisk].Value() + if diskPath == "" { + return wizard.NodeConfig{}, fmt.Errorf("install disk is required") + } + + address := m.nodeInputs[fieldAddress].Value() + if address == "" { + return wizard.NodeConfig{}, fmt.Errorf("address (CIDR) is required") + } + if err := wizard.ValidateCIDR(address); err != nil { + return wizard.NodeConfig{}, fmt.Errorf("address: %w", err) + } + + gateway := m.nodeInputs[fieldGateway].Value() + if gateway != "" { + if err := wizard.ValidateIP(gateway); err != nil { + return wizard.NodeConfig{}, fmt.Errorf("gateway: %w", err) + } + } + + var dns []string + dnsStr := m.nodeInputs[fieldDNS].Value() + if dnsStr != "" { + for _, d := range strings.Split(dnsStr, ",") { + d = strings.TrimSpace(d) + if d == "" { + continue + } + if err := wizard.ValidateIP(d); err != nil { + return wizard.NodeConfig{}, fmt.Errorf("DNS %q: %w", d, err) + } + dns = append(dns, d) + } + } + + managementIP := strings.TrimSpace(m.nodeInputs[fieldManagementIP].Value()) + if managementIP != "" { + if err := wizard.ValidateIP(managementIP); err != nil { + return wizard.NodeConfig{}, fmt.Errorf("management IP: %w", err) + } + } + + return wizard.NodeConfig{ + Hostname: hostname, + Role: role, + DiskPath: diskPath, + Interface: m.nodeInputs[fieldInterface].Value(), + Addresses: address, + Gateway: gateway, + DNS: dns, + ManagementIP: managementIP, + }, nil +} + +func (m Model) updateConfirm(msg tea.Msg) (tea.Model, tea.Cmd) { + if keyMsg, ok := msg.(tea.KeyMsg); ok { + switch keyMsg.String() { + case "y", "enter": + m.step = stepGenerating + return m, tea.Batch( + m.spinner.Tick, + generateCmd(m.generateFn, m.result), + ) + case "n": + m.step = stepSelectPreset + m.configuredNodes = nil + m.selectedNodes = nil + return m, nil + } + } + return m, nil +} + +func (m Model) updateDone(msg tea.Msg) (tea.Model, tea.Cmd) { + if keyMsg, ok := msg.(tea.KeyMsg); ok { + switch keyMsg.String() { + case "enter", "q": + return m, tea.Quit + } + } + return m, nil +} + +func (m Model) updateError(msg tea.Msg) (tea.Model, tea.Cmd) { + if keyMsg, ok := msg.(tea.KeyMsg); ok { + switch keyMsg.String() { + case "enter", "q": + return m, tea.Quit + case "r": + m.step = stepSelectPreset + m.err = nil + return m, nil + } + } + return m, nil +} + +// Async command functions. + +func scanNetworkCmd(ctx context.Context, scanner wizard.Scanner, cidr string) tea.Cmd { + return func() tea.Msg { + result, err := scanner.ScanNetworkFull(ctx, cidr) + if err != nil { + return scanErrorMsg{err: err} + } + return scanResultMsg{nodes: result.Nodes, warnings: result.Warnings} + } +} + +func generateCmd(fn GenerateFunc, result wizard.WizardResult) tea.Cmd { + return func() tea.Msg { + if fn == nil { + return generateDoneMsg{} + } + if err := fn(result); err != nil { + return generateErrorMsg{err: err} + } + return generateDoneMsg{} + } +} diff --git a/pkg/wizard/tui/model_test.go b/pkg/wizard/tui/model_test.go new file mode 100644 index 0000000..e5114d5 --- /dev/null +++ b/pkg/wizard/tui/model_test.go @@ -0,0 +1,952 @@ +package tui + +import ( + "context" + "fmt" + "strings" + "testing" + + tea "github.com/charmbracelet/bubbletea" + + "github.com/cozystack/talm/pkg/wizard" +) + +type mockScanner struct { + nodes []wizard.NodeInfo + err error +} + +func (m *mockScanner) ScanNetwork(_ context.Context, _ string) ([]wizard.NodeInfo, error) { + return m.nodes, m.err +} + +func (m *mockScanner) ScanNetworkFull(_ context.Context, _ string) (wizard.ScanResult, error) { + return wizard.ScanResult{Nodes: m.nodes}, m.err +} + +func (m *mockScanner) GetNodeInfo(_ context.Context, ip string) (wizard.NodeInfo, error) { + for _, n := range m.nodes { + if n.IP == ip { + return n, nil + } + } + return wizard.NodeInfo{IP: ip}, nil +} + +func enterMsg() tea.Msg { + return tea.KeyMsg{Type: tea.KeyEnter} +} + +func escMsg() tea.Msg { + return tea.KeyMsg{Type: tea.KeyEsc} +} + +func keyMsg(key string) tea.Msg { + return tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune(key)} +} + +func TestInitialStep(t *testing.T) { + m := New(&mockScanner{}, []string{"generic", "cozystack"}, nil) + if m.Step() != stepSelectPreset { + t.Errorf("initial step = %d, want stepSelectPreset (%d)", m.Step(), stepSelectPreset) + } +} + +func TestSelectPreset(t *testing.T) { + m := New(&mockScanner{}, []string{"generic", "cozystack"}, nil) + + updated, _ := m.Update(enterMsg()) + m = updated.(Model) + + if m.Step() != stepClusterName { + t.Errorf("step = %d, want stepClusterName (%d)", m.Step(), stepClusterName) + } + if m.result.Preset != "generic" { + t.Errorf("preset = %q, want %q", m.result.Preset, "generic") + } +} + +func TestSelectSecondPreset(t *testing.T) { + m := New(&mockScanner{}, []string{"generic", "cozystack"}, nil) + + updated, _ := m.Update(keyMsg("j")) + m = updated.(Model) + updated, _ = m.Update(enterMsg()) + m = updated.(Model) + + if m.result.Preset != "cozystack" { + t.Errorf("preset = %q, want %q", m.result.Preset, "cozystack") + } +} + +func TestClusterNameValidation(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + + updated, _ := m.Update(enterMsg()) // select preset + m = updated.(Model) + updated, _ = m.Update(enterMsg()) // submit empty name + m = updated.(Model) + + if m.Step() != stepClusterName { + t.Errorf("should stay on stepClusterName with empty name, got step %d", m.Step()) + } + if m.err == nil { + t.Error("expected validation error for empty cluster name") + } +} + +func TestClusterNameSuccess(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + + updated, _ := m.Update(enterMsg()) // select preset + m = updated.(Model) + for _, ch := range "test" { + updated, _ = m.Update(keyMsg(string(ch))) + m = updated.(Model) + } + updated, _ = m.Update(enterMsg()) // submit name + m = updated.(Model) + + if m.Step() != stepEndpoint { + t.Errorf("step = %d, want stepEndpoint (%d)", m.Step(), stepEndpoint) + } + if m.result.ClusterName != "test" { + t.Errorf("clusterName = %q, want %q", m.result.ClusterName, "test") + } +} + +func TestBackNavigation(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + + updated, _ := m.Update(enterMsg()) // go to cluster name + m = updated.(Model) + updated, _ = m.Update(escMsg()) // go back + m = updated.(Model) + + if m.Step() != stepSelectPreset { + t.Errorf("step = %d, want stepSelectPreset (%d)", m.Step(), stepSelectPreset) + } +} + +func TestEndpointValidation(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + + // Navigate to endpoint step + updated, _ := m.Update(enterMsg()) + m = updated.(Model) + for _, ch := range "test" { + updated, _ = m.Update(keyMsg(string(ch))) + m = updated.(Model) + } + updated, _ = m.Update(enterMsg()) + m = updated.(Model) + + // Submit empty endpoint + updated, _ = m.Update(enterMsg()) + m = updated.(Model) + + if m.Step() != stepEndpoint { + t.Errorf("should stay on stepEndpoint with empty value, got step %d", m.Step()) + } + if m.err == nil { + t.Error("expected validation error for empty endpoint") + } +} + +func TestScanResultTransition(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepScanning + + nodes := []wizard.NodeInfo{ + {IP: "10.0.0.1", Hostname: "node-01"}, + {IP: "10.0.0.2", Hostname: "node-02"}, + } + + updated, _ := m.Update(scanResultMsg{nodes: nodes}) + m = updated.(Model) + + if m.Step() != stepSelectNodes { + t.Errorf("step = %d, want stepSelectNodes (%d)", m.Step(), stepSelectNodes) + } + if len(m.discoveredNodes) != 2 { + t.Errorf("discoveredNodes = %d, want 2", len(m.discoveredNodes)) + } +} + +func TestScanResultEmpty(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepScanning + + updated, _ := m.Update(scanResultMsg{nodes: nil}) + m = updated.(Model) + + if m.Step() != stepError { + t.Errorf("step = %d, want stepError (%d)", m.Step(), stepError) + } +} + +func TestScanError(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepScanning + + updated, _ := m.Update(scanErrorMsg{err: fmt.Errorf("nmap failed")}) + m = updated.(Model) + + if m.Step() != stepError { + t.Errorf("step = %d, want stepError (%d)", m.Step(), stepError) + } + if m.err == nil || m.err.Error() != "nmap failed" { + t.Errorf("err = %v, want 'nmap failed'", m.err) + } +} + +func TestNodeSelection(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepSelectNodes + m.discoveredNodes = []wizard.NodeInfo{ + {IP: "10.0.0.1"}, + {IP: "10.0.0.2"}, + } + + // Toggle first node + updated, _ := m.Update(keyMsg(" ")) + m = updated.(Model) + + if len(m.selectedNodes) != 1 || m.selectedNodes[0] != 0 { + t.Errorf("selectedNodes = %v, want [0]", m.selectedNodes) + } + + // Toggle again (deselect) + updated, _ = m.Update(keyMsg(" ")) + m = updated.(Model) + + if len(m.selectedNodes) != 0 { + t.Errorf("selectedNodes = %v, want empty", m.selectedNodes) + } +} + +func TestConfirmToGenerate(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, func(_ wizard.WizardResult) error { + return nil + }) + m.step = stepConfirm + m.result = wizard.WizardResult{ + Preset: "generic", + ClusterName: "test", + Endpoint: "https://10.0.0.1:6443", + } + + updated, _ := m.Update(keyMsg("y")) + m = updated.(Model) + + if m.Step() != stepGenerating { + t.Errorf("step = %d, want stepGenerating (%d)", m.Step(), stepGenerating) + } +} + +func TestGenerateDone(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepGenerating + + updated, _ := m.Update(generateDoneMsg{}) + m = updated.(Model) + + if m.Step() != stepDone { + t.Errorf("step = %d, want stepDone (%d)", m.Step(), stepDone) + } +} + +func TestGenerateError(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepGenerating + + updated, _ := m.Update(generateErrorMsg{err: fmt.Errorf("write failed")}) + m = updated.(Model) + + if m.Step() != stepError { + t.Errorf("step = %d, want stepError (%d)", m.Step(), stepError) + } +} + +func TestWindowResize(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + + updated, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40}) + m = updated.(Model) + + if m.width != 120 || m.height != 40 { + t.Errorf("dimensions = %dx%d, want 120x40", m.width, m.height) + } +} + +// Manual node entry tests + +func TestSkipScanTransition(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepScanCIDR + + updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyCtrlS}) + m = updated.(Model) + + if m.Step() != stepManualNodeEntry { + t.Errorf("step = %d, want stepManualNodeEntry (%d)", m.Step(), stepManualNodeEntry) + } +} + +func TestManualNodeEntry_AddAndDone(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepManualNodeEntry + + // Set IP directly (textinput doesn't process rune messages without Focus) + m.manualIPInput.SetValue("10.0.0.1") + + // Add it + updated, _ := m.Update(enterMsg()) + m = updated.(Model) + + if len(m.manualNodes) != 1 { + t.Fatalf("expected 1 manual node, got %d", len(m.manualNodes)) + } + if m.manualNodes[0].IP != "10.0.0.1" { + t.Errorf("IP = %q, want 10.0.0.1", m.manualNodes[0].IP) + } + + // Press ctrl+d to finish + updated, _ = m.Update(tea.KeyMsg{Type: tea.KeyCtrlD}) + m = updated.(Model) + + if m.Step() != stepSelectNodes { + t.Errorf("step = %d, want stepSelectNodes (%d)", m.Step(), stepSelectNodes) + } + if len(m.selectedNodes) != 1 { + t.Error("manual nodes should be pre-selected") + } +} + +func TestManualNodeEntry_InvalidIP(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepManualNodeEntry + + m.manualIPInput.SetValue("not-an-ip") + + updated, _ := m.Update(enterMsg()) + m = updated.(Model) + + if m.err == nil { + t.Error("expected validation error for invalid IP") + } + if m.Step() != stepManualNodeEntry { + t.Error("should stay on manual entry step") + } +} + +func TestManualNodeEntry_DoneWithoutNodes(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepManualNodeEntry + + updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyCtrlD}) + m = updated.(Model) + + if m.err == nil { + t.Error("expected error when pressing done with no nodes") + } + if m.Step() != stepManualNodeEntry { + t.Error("should stay on manual entry step") + } +} + +// Node configuration validation tests + +func TestNodeConfigValidation_InvalidRole(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepConfigureNode + m.discoveredNodes = []wizard.NodeInfo{{IP: "10.0.0.1"}} + m.selectedNodes = []int{0} + m.currentNodeIdx = 0 + m.prepareNodeInputs() + + // Set invalid role + m.nodeInputs[fieldRole].SetValue("master") + m.nodeInputs[fieldHostname].SetValue("node-01") + + updated, _ := m.Update(enterMsg()) + m = updated.(Model) + + if m.err == nil { + t.Error("expected validation error for invalid role") + } + if m.Step() != stepConfigureNode { + t.Error("should stay on configure step on validation error") + } +} + +func TestNodeConfigValidation_InvalidHostname(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepConfigureNode + m.discoveredNodes = []wizard.NodeInfo{{IP: "10.0.0.1"}} + m.selectedNodes = []int{0} + m.currentNodeIdx = 0 + m.prepareNodeInputs() + + m.nodeInputs[fieldRole].SetValue("controlplane") + m.nodeInputs[fieldHostname].SetValue("-bad-name") + + updated, _ := m.Update(enterMsg()) + m = updated.(Model) + + if m.err == nil { + t.Error("expected validation error for invalid hostname") + } +} + +func TestNodeConfigValidation_Success(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepConfigureNode + m.discoveredNodes = []wizard.NodeInfo{{IP: "10.0.0.1"}} + m.selectedNodes = []int{0} + m.currentNodeIdx = 0 + m.prepareNodeInputs() + + m.nodeInputs[fieldRole].SetValue("controlplane") + m.nodeInputs[fieldHostname].SetValue("cp-1") + m.nodeInputs[fieldDisk].SetValue("/dev/sda") + m.nodeInputs[fieldInterface].SetValue("eth0") + m.nodeInputs[fieldAddress].SetValue("10.0.0.1/24") + m.nodeInputs[fieldGateway].SetValue("10.0.0.254") + m.nodeInputs[fieldDNS].SetValue("8.8.8.8,1.1.1.1") + + updated, _ := m.Update(enterMsg()) + m = updated.(Model) + + if m.Step() != stepConfirm { + t.Errorf("step = %d, want stepConfirm (%d), err = %v", m.Step(), stepConfirm, m.err) + } + if len(m.result.Nodes) != 1 { + t.Fatalf("expected 1 configured node, got %d", len(m.result.Nodes)) + } + n := m.result.Nodes[0] + if n.Role != "controlplane" { + t.Errorf("role = %q, want controlplane", n.Role) + } + if n.Gateway != "10.0.0.254" { + t.Errorf("gateway = %q, want 10.0.0.254", n.Gateway) + } + if len(n.DNS) != 2 || n.DNS[0] != "8.8.8.8" || n.DNS[1] != "1.1.1.1" { + t.Errorf("DNS = %v, want [8.8.8.8 1.1.1.1]", n.DNS) + } +} + +func TestNodeConfigValidation_EmptyAddress(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepConfigureNode + m.discoveredNodes = []wizard.NodeInfo{{IP: "10.0.0.1"}} + m.selectedNodes = []int{0} + m.currentNodeIdx = 0 + m.prepareNodeInputs() + + m.nodeInputs[fieldRole].SetValue("controlplane") + m.nodeInputs[fieldHostname].SetValue("cp-1") + m.nodeInputs[fieldDisk].SetValue("/dev/sda") + m.nodeInputs[fieldAddress].SetValue("") + + updated, _ := m.Update(enterMsg()) + m = updated.(Model) + + if m.err == nil { + t.Error("expected validation error for empty address") + } +} + +func TestNodeConfigValidation_EmptyDisk(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepConfigureNode + m.discoveredNodes = []wizard.NodeInfo{{IP: "10.0.0.1"}} + m.selectedNodes = []int{0} + m.currentNodeIdx = 0 + m.prepareNodeInputs() + + m.nodeInputs[fieldRole].SetValue("controlplane") + m.nodeInputs[fieldHostname].SetValue("cp-1") + m.nodeInputs[fieldDisk].SetValue("") // empty disk + + updated, _ := m.Update(enterMsg()) + m = updated.(Model) + + if m.err == nil { + t.Error("expected validation error for empty disk path") + } + if m.Step() != stepConfigureNode { + t.Error("should stay on configure step") + } +} + +func TestNodeConfigDefaultRole(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepConfigureNode + m.discoveredNodes = []wizard.NodeInfo{{IP: "10.0.0.1"}, {IP: "10.0.0.2"}} + m.selectedNodes = []int{0, 1} + + m.currentNodeIdx = 0 + m.prepareNodeInputs() + if m.nodeInputs[fieldRole].Value() != "controlplane" { + t.Errorf("first node role = %q, want controlplane", m.nodeInputs[fieldRole].Value()) + } + + m.currentNodeIdx = 1 + m.prepareNodeInputs() + if m.nodeInputs[fieldRole].Value() != "worker" { + t.Errorf("second node role = %q, want worker", m.nodeInputs[fieldRole].Value()) + } +} + +// Verify stale scan results are ignored after cancellation + +func TestStaleScanResult_Ignored(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepScanCIDR // already back from scanning + + // Deliver a stale scan result — should be ignored + updated, _ := m.Update(scanResultMsg{nodes: []wizard.NodeInfo{{IP: "10.0.0.1"}}}) + m = updated.(Model) + + if m.Step() != stepScanCIDR { + t.Errorf("stale scan result should not change step, got %d", m.Step()) + } +} + +// Verify the done step allows exiting the program + +func TestDoneStep_EnterQuits(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepDone + + _, cmd := m.Update(enterMsg()) + if cmd == nil { + t.Fatal("expected tea.Quit cmd on enter at stepDone, got nil") + } +} + +func TestDoneStep_QKeyQuits(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepDone + + _, cmd := m.Update(keyMsg("q")) + if cmd == nil { + t.Fatal("expected tea.Quit cmd on 'q' at stepDone, got nil") + } +} + +// Verify back navigation restores previous node's data in the input fields + +func TestBackFromConfigureNode_RestoresInputs(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepConfigureNode + m.discoveredNodes = []wizard.NodeInfo{ + {IP: "10.0.0.1", Hostname: "first-node"}, + {IP: "10.0.0.2", Hostname: "second-node"}, + } + m.selectedNodes = []int{0, 1} + + // Configure first node + m.currentNodeIdx = 0 + m.prepareNodeInputs() + m.nodeInputs[fieldHostname].SetValue("first-node") + m.nodeInputs[fieldRole].SetValue("controlplane") + m.configuredNodes = append(m.configuredNodes, wizard.NodeConfig{Hostname: "first-node"}) + m.currentNodeIdx = 1 + m.prepareNodeInputs() + + // Now go back + updated, _ := m.Update(escMsg()) + m = updated.(Model) + + if m.currentNodeIdx != 0 { + t.Errorf("currentNodeIdx = %d, want 0", m.currentNodeIdx) + } + // After back, prepareNodeInputs should have restored first-node's hostname + if m.nodeInputs[fieldHostname].Value() != "first-node" { + t.Errorf("hostname = %q, want first-node", m.nodeInputs[fieldHostname].Value()) + } +} + +// Verify back from confirm doesn't panic and restores last node + +func TestBackFromConfirm_NoPanic(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepConfirm + m.discoveredNodes = []wizard.NodeInfo{{IP: "10.0.0.1", Hostname: "cp-1"}} + m.selectedNodes = []int{0} + m.currentNodeIdx = 1 // past the last node (confirm was reached) + m.configuredNodes = []wizard.NodeConfig{{Hostname: "cp-1", Role: "controlplane"}} + m.result.Nodes = m.configuredNodes + + // Press Esc — should not panic + updated, _ := m.Update(escMsg()) + m = updated.(Model) + + if m.Step() != stepConfigureNode { + t.Errorf("step = %d, want stepConfigureNode", m.Step()) + } + if m.currentNodeIdx != 0 { + t.Errorf("currentNodeIdx = %d, want 0", m.currentNodeIdx) + } +} + +// Verify back from confirm with single node doesn't create duplicates + +func TestBackFromConfirm_SingleNode_NoDuplicate(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepConfigureNode + m.discoveredNodes = []wizard.NodeInfo{{IP: "10.0.0.1", Hostname: "cp-1"}} + m.selectedNodes = []int{0} + m.currentNodeIdx = 0 + m.prepareNodeInputs() + + // Configure the node + m.nodeInputs[fieldRole].SetValue("controlplane") + m.nodeInputs[fieldHostname].SetValue("cp-1") + m.nodeInputs[fieldDisk].SetValue("/dev/sda") + m.nodeInputs[fieldAddress].SetValue("10.0.0.1/24") + m.nodeInputs[fieldDNS].SetValue("8.8.8.8") + + updated, _ := m.Update(enterMsg()) // -> confirm + m = updated.(Model) + + if m.Step() != stepConfirm { + t.Fatalf("expected stepConfirm, got %d", m.Step()) + } + + // Go back + updated, _ = m.Update(escMsg()) + m = updated.(Model) + + // Re-enter the same node + m.nodeInputs[fieldRole].SetValue("controlplane") + m.nodeInputs[fieldHostname].SetValue("cp-1") + m.nodeInputs[fieldDisk].SetValue("/dev/sda") + m.nodeInputs[fieldAddress].SetValue("10.0.0.1/24") + m.nodeInputs[fieldDNS].SetValue("8.8.8.8") + + updated, _ = m.Update(enterMsg()) // -> confirm again + m = updated.(Model) + + if len(m.result.Nodes) != 1 { + t.Errorf("expected 1 node, got %d (duplicate created on back-forward)", len(m.result.Nodes)) + } +} + +// Verify Esc from scanning cancels context and returns to CIDR step + +func TestEscFromScanning(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepScanning + cancelled := false + m.cancelScan = func() { cancelled = true } + + updated, _ := m.Update(escMsg()) + m = updated.(Model) + + if m.Step() != stepScanCIDR { + t.Errorf("step = %d, want stepScanCIDR", m.Step()) + } + if !cancelled { + t.Error("scan context should have been cancelled") + } + if m.cancelScan != nil { + t.Error("cancelScan should be nil after cancellation") + } +} + +// Esc from stepError must land on a user-actionable step, not on the +// spinner the generation was running from. + +func TestErrorBack_ReturnsToConfirm(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepGenerating + + updated, _ := m.Update(generateErrorMsg{err: fmt.Errorf("disk full")}) + m = updated.(Model) + + if m.Step() != stepError { + t.Fatalf("expected stepError, got %d", m.Step()) + } + + updated, _ = m.Update(escMsg()) + m = updated.(Model) + + if m.Step() != stepConfirm { + t.Errorf("expected Esc to return to stepConfirm (actionable), got %d", m.Step()) + } +} + +// §14 — viewDone must tell user how to exit + +func TestViewDone_ShowsExitHint(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepDone + out := m.View() + if !strings.Contains(strings.ToLower(out), "enter") || !strings.Contains(strings.ToLower(out), "q") { + t.Errorf("viewDone should mention Enter and q keys to exit, got:\n%s", out) + } +} + +// §12 — rescan must reset selectedNodes/cursor/scanWarnings + +func TestRescanResetsSelectedCursorWarnings(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepScanning + m.selectedNodes = []int{5, 7, 9} + m.cursor = 4 + m.scanWarnings = []string{"old warning"} + + updated, _ := m.Update(scanResultMsg{nodes: []wizard.NodeInfo{{IP: "10.0.0.1"}}, warnings: nil}) + m = updated.(Model) + + if len(m.selectedNodes) != 0 { + t.Errorf("selectedNodes should be reset on rescan, got %v", m.selectedNodes) + } + if m.cursor != 0 { + t.Errorf("cursor should be reset on rescan, got %d", m.cursor) + } + if len(m.scanWarnings) != 0 { + t.Errorf("scanWarnings should be replaced on rescan, got %v", m.scanWarnings) + } +} + +// §12 — scanErrorMsg should capture the step *before* stepScanning so Esc returns to CIDR input + +func TestScanError_PrevStepIsNotScanning(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + // Simulate the real flow: user entered CIDR, pressed enter → stepScanning + m.step = stepScanning + + updated, _ := m.Update(scanErrorMsg{err: fmt.Errorf("boom")}) + m = updated.(Model) + + if m.Step() != stepError { + t.Fatalf("expected stepError, got %d", m.Step()) + } + if m.prevStep == nil { + t.Fatal("prevStep must be set") + } + if *m.prevStep == stepScanning { + t.Errorf("prevStep must not be stepScanning (Esc would land on inert spinner), got stepScanning") + } + if *m.prevStep != stepScanCIDR { + t.Errorf("prevStep should be stepScanCIDR, got %d", *m.prevStep) + } +} + +// §12 — generateErrorMsg should capture stepConfirm, not stepGenerating + +func TestGenerateError_PrevStepIsConfirm(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepGenerating + + updated, _ := m.Update(generateErrorMsg{err: fmt.Errorf("fail")}) + m = updated.(Model) + + if m.prevStep == nil || *m.prevStep == stepGenerating { + t.Errorf("prevStep must not be stepGenerating (inert spinner), got %v", m.prevStep) + } + if m.prevStep != nil && *m.prevStep != stepConfirm { + t.Errorf("prevStep should be stepConfirm, got %d", *m.prevStep) + } +} + +// §13 — back-navigation from stepConfirm must preserve edits of the last node + +func TestBack_PreservesLastNodeEdits(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepConfigureNode + m.discoveredNodes = []wizard.NodeInfo{{IP: "10.0.0.1", Hostname: "cp-1"}} + m.selectedNodes = []int{0} + m.currentNodeIdx = 0 + m.prepareNodeInputs() + + // Simulate user typing — custom DNS that differs from default + m.nodeInputs[fieldRole].SetValue("controlplane") + m.nodeInputs[fieldHostname].SetValue("cp-1") + m.nodeInputs[fieldDisk].SetValue("/dev/nvme0n1") + m.nodeInputs[fieldInterface].SetValue("eth1") + m.nodeInputs[fieldAddress].SetValue("10.0.0.1/24") + m.nodeInputs[fieldGateway].SetValue("10.0.0.254") + m.nodeInputs[fieldDNS].SetValue("1.1.1.1,9.9.9.9") + + updated, _ := m.Update(enterMsg()) // → stepConfirm + m = updated.(Model) + if m.Step() != stepConfirm { + t.Fatalf("expected stepConfirm after enter, got %d", m.Step()) + } + + // Go back — user wants to tweak something + updated, _ = m.Update(escMsg()) + m = updated.(Model) + + if m.Step() != stepConfigureNode { + t.Fatalf("expected stepConfigureNode after back, got %d", m.Step()) + } + // Inputs must still carry the user's edits (they are editing, not re-entering) + if got := m.nodeInputs[fieldDNS].Value(); got != "1.1.1.1,9.9.9.9" { + t.Errorf("DNS should be preserved, got %q", got) + } + if got := m.nodeInputs[fieldGateway].Value(); got != "10.0.0.254" { + t.Errorf("Gateway should be preserved, got %q", got) + } + if got := m.nodeInputs[fieldDisk].Value(); got != "/dev/nvme0n1" { + t.Errorf("Disk should be preserved, got %q", got) + } +} + +// §13 — back from configure node (second→first) must preserve first node's edits + +func TestBack_BetweenNodes_PreservesFirstEdits(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepConfigureNode + m.discoveredNodes = []wizard.NodeInfo{ + {IP: "10.0.0.1", Hostname: "cp-1"}, + {IP: "10.0.0.2", Hostname: "cp-2"}, + } + m.selectedNodes = []int{0, 1} + m.currentNodeIdx = 0 + m.prepareNodeInputs() + + m.nodeInputs[fieldRole].SetValue("controlplane") + m.nodeInputs[fieldHostname].SetValue("cp-1") + m.nodeInputs[fieldDisk].SetValue("/dev/nvme0n1") + m.nodeInputs[fieldAddress].SetValue("10.0.0.1/24") + m.nodeInputs[fieldDNS].SetValue("1.1.1.1") + + updated, _ := m.Update(enterMsg()) // advance to node 2 + m = updated.(Model) + if m.currentNodeIdx != 1 { + t.Fatalf("expected currentNodeIdx=1, got %d", m.currentNodeIdx) + } + + // Go back to node 1 + updated, _ = m.Update(escMsg()) + m = updated.(Model) + + if m.currentNodeIdx != 0 { + t.Fatalf("currentNodeIdx = %d, want 0", m.currentNodeIdx) + } + if got := m.nodeInputs[fieldDisk].Value(); got != "/dev/nvme0n1" { + t.Errorf("first node disk should be preserved, got %q", got) + } + if got := m.nodeInputs[fieldDNS].Value(); got != "1.1.1.1" { + t.Errorf("first node DNS should be preserved, got %q", got) + } +} + +// §3 — DNS field should not be auto-prefilled with 8.8.8.8 + +func TestPrepareNodeInputs_DNSNotPrefilled(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.discoveredNodes = []wizard.NodeInfo{{IP: "10.0.0.1"}} + m.selectedNodes = []int{0} + m.currentNodeIdx = 0 + + m.prepareNodeInputs() + + if got := m.nodeInputs[fieldDNS].Value(); got != "" { + t.Errorf("DNS should not be prefilled, got %q", got) + } +} + +// §3 — role field should be a toggle (space switches between controlplane/worker) + +func TestConfigureNode_RoleToggleWithSpace(t *testing.T) { + m := New(&mockScanner{}, []string{"generic"}, nil) + m.step = stepConfigureNode + m.discoveredNodes = []wizard.NodeInfo{{IP: "10.0.0.1"}} + m.selectedNodes = []int{0} + m.currentNodeIdx = 0 + m.prepareNodeInputs() + m.nodeInputFocus = fieldRole + + initialRole := m.nodeInputs[fieldRole].Value() + updated, _ := m.Update(tea.KeyMsg{Type: tea.KeySpace}) + m = updated.(Model) + + if got := m.nodeInputs[fieldRole].Value(); got == initialRole { + t.Errorf("space on role field should toggle; still %q", got) + } + if got := m.nodeInputs[fieldRole].Value(); got != "controlplane" && got != "worker" { + t.Errorf("role toggle produced invalid value %q", got) + } +} + +// §1 — when launched on an already-initialized project the wizard skips +// the preset/cluster-name collection and jumps straight to endpoint input. +// preset/name come from the on-disk Chart.yaml via the caller. + +func TestNewForExistingProject_SkipsPresetAndName(t *testing.T) { + m := NewForExistingProject(&mockScanner{}, wizard.WizardResult{ + Preset: "generic", + ClusterName: "existing", + }, nil) + + if m.Step() != stepEndpoint { + t.Errorf("existing-project wizard should start at stepEndpoint, got %d", m.Step()) + } + if m.result.Preset != "generic" { + t.Errorf("preset should be pre-set, got %q", m.result.Preset) + } + if m.result.ClusterName != "existing" { + t.Errorf("cluster name should be pre-set, got %q", m.result.ClusterName) + } +} + +// View rendering tests + +func TestViewRendersWithoutPanic(t *testing.T) { + m := New(&mockScanner{}, []string{"generic", "cozystack"}, nil) + + steps := []step{ + stepSelectPreset, stepClusterName, stepEndpoint, + stepScanCIDR, stepScanning, stepManualNodeEntry, stepDone, + } + + for _, s := range steps { + m.step = s + output := m.View() + if output == "" { + t.Errorf("View() returned empty string for step %d", s) + } + } + + // Error view + m.step = stepError + m.err = fmt.Errorf("test error") + if m.View() == "" { + t.Error("View() returned empty for error step") + } + + // Select nodes view + m.step = stepSelectNodes + m.discoveredNodes = []wizard.NodeInfo{{IP: "10.0.0.1", Hostname: "node-01"}} + if m.View() == "" { + t.Error("View() returned empty for selectNodes step") + } + + // Configure node view + m.step = stepConfigureNode + m.discoveredNodes = []wizard.NodeInfo{{IP: "10.0.0.1"}} + m.selectedNodes = []int{0} + m.currentNodeIdx = 0 + if m.View() == "" { + t.Error("View() returned empty for configureNode step") + } + + // Confirm view + m.step = stepConfirm + m.result = wizard.WizardResult{ + Preset: "generic", + ClusterName: "test", + Endpoint: "https://10.0.0.1:6443", + Nodes: []wizard.NodeConfig{{Hostname: "cp-1", Role: "controlplane", DNS: []string{"8.8.8.8"}}}, + } + if m.View() == "" { + t.Error("View() returned empty for confirm step") + } +} diff --git a/pkg/wizard/tui/styles.go b/pkg/wizard/tui/styles.go new file mode 100644 index 0000000..9bdeed2 --- /dev/null +++ b/pkg/wizard/tui/styles.go @@ -0,0 +1,36 @@ +package tui + +import "github.com/charmbracelet/lipgloss" + +var ( + titleStyle = lipgloss.NewStyle(). + Bold(true). + Foreground(lipgloss.Color("170")). + MarginBottom(1) + + subtitleStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("241")). + MarginBottom(1) + + focusedStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("205")) + + blurredStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("240")) + + errorStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("196")). + Bold(true) + + successStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("46")). + Bold(true) + + helpStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("241")). + MarginTop(1) + + selectedStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("170")). + Bold(true) +) diff --git a/pkg/wizard/tui/views.go b/pkg/wizard/tui/views.go new file mode 100644 index 0000000..11701a7 --- /dev/null +++ b/pkg/wizard/tui/views.go @@ -0,0 +1,272 @@ +package tui + +import ( + "fmt" + "strings" + + "github.com/dustin/go-humanize" +) + +// View implements tea.Model. +func (m Model) View() string { + switch m.step { + case stepSelectPreset: + return m.viewSelectPreset() + case stepClusterName: + return m.viewClusterName() + case stepEndpoint: + return m.viewEndpoint() + case stepScanCIDR: + return m.viewScanCIDR() + case stepScanning: + return m.viewScanning() + case stepManualNodeEntry: + return m.viewManualNodeEntry() + case stepSelectNodes: + return m.viewSelectNodes() + case stepConfigureNode: + return m.viewConfigureNode() + case stepConfirm: + return m.viewConfirm() + case stepGenerating: + return m.viewGenerating() + case stepDone: + return m.viewDone() + case stepError: + return m.viewError() + default: + return "" + } +} + +func (m Model) viewSelectPreset() string { + var b strings.Builder + b.WriteString(titleStyle.Render("Select a preset")) + b.WriteString("\n\n") + + for i, preset := range m.presets { + cursor := " " + style := blurredStyle + if i == m.cursor { + cursor = "> " + style = selectedStyle + } + b.WriteString(cursor + style.Render(preset) + "\n") + } + + b.WriteString(helpStyle.Render("\nup/down navigate | enter select | ctrl+c quit")) + return b.String() +} + +func (m Model) viewClusterName() string { + var b strings.Builder + b.WriteString(titleStyle.Render("Cluster name")) + b.WriteString("\n\n") + b.WriteString(m.nameInput.View()) + + if m.err != nil { + b.WriteString("\n" + errorStyle.Render(m.err.Error())) + } + + b.WriteString(helpStyle.Render("\nenter confirm | esc back")) + return b.String() +} + +func (m Model) viewEndpoint() string { + var b strings.Builder + b.WriteString(titleStyle.Render("API server endpoint")) + b.WriteString("\n\n") + b.WriteString(m.endpointInput.View()) + + if m.err != nil { + b.WriteString("\n" + errorStyle.Render(m.err.Error())) + } + + b.WriteString(helpStyle.Render("\nenter confirm | esc back")) + return b.String() +} + +func (m Model) viewScanCIDR() string { + var b strings.Builder + b.WriteString(titleStyle.Render("Network to scan")) + b.WriteString("\n") + b.WriteString(subtitleStyle.Render("Enter CIDR range to discover Talos nodes, or press Ctrl+S to enter IPs manually")) + b.WriteString("\n\n") + b.WriteString(m.cidrInput.View()) + + if m.err != nil { + b.WriteString("\n" + errorStyle.Render(m.err.Error())) + } + + b.WriteString(helpStyle.Render("\nenter scan | ctrl+s skip scan (manual entry) | esc back")) + return b.String() +} + +func (m Model) viewScanning() string { + return titleStyle.Render("Scanning network...") + "\n\n" + + m.spinner.View() + " Discovering Talos nodes...\n" +} + +func (m Model) viewManualNodeEntry() string { + var b strings.Builder + b.WriteString(titleStyle.Render("Manual node entry")) + b.WriteString("\n") + b.WriteString(subtitleStyle.Render("Enter node IP addresses one by one")) + b.WriteString("\n\n") + + if len(m.manualNodes) > 0 { + b.WriteString("Added nodes:\n") + for _, n := range m.manualNodes { + b.WriteString(" " + successStyle.Render(n.IP) + "\n") + } + b.WriteString("\n") + } + + b.WriteString(m.manualIPInput.View()) + + if m.err != nil { + b.WriteString("\n" + errorStyle.Render(m.err.Error())) + } + + b.WriteString(helpStyle.Render("\nenter add node | ctrl+d done | esc back")) + return b.String() +} + +func (m Model) viewSelectNodes() string { + var b strings.Builder + b.WriteString(titleStyle.Render("Select nodes")) + fmt.Fprintf(&b, "\n%d node(s) discovered\n\n", len(m.discoveredNodes)) + + for i, node := range m.discoveredNodes { + cursor := " " + if i == m.cursor { + cursor = "> " + } + + selected := "[ ]" + for _, idx := range m.selectedNodes { + if idx == i { + selected = "[x]" + break + } + } + + info := node.IP + if node.Hostname != "" { + info += " (" + node.Hostname + ")" + } + if node.RAMBytes > 0 { + info += " " + humanize.IBytes(node.RAMBytes) + " RAM" + } + if len(node.Disks) > 0 { + info += " " + node.Disks[0].Model + } + + fmt.Fprintf(&b, "%s%s %s\n", cursor, selected, info) + } + + if len(m.scanWarnings) > 0 { + b.WriteString("\n" + errorStyle.Render(fmt.Sprintf("%d node(s) found but failed gRPC:", len(m.scanWarnings)))) + for _, w := range m.scanWarnings { + b.WriteString("\n " + blurredStyle.Render(w)) + } + b.WriteString("\n") + } + + if m.err != nil { + b.WriteString("\n" + errorStyle.Render(m.err.Error())) + } + + b.WriteString(helpStyle.Render("\nup/down navigate | space toggle | enter confirm | esc back")) + return b.String() +} + +func (m Model) viewConfigureNode() string { + var b strings.Builder + nodeIdx := m.selectedNodes[m.currentNodeIdx] + node := m.discoveredNodes[nodeIdx] + + b.WriteString(titleStyle.Render(fmt.Sprintf("Configure node %d/%d", m.currentNodeIdx+1, len(m.selectedNodes)))) + fmt.Fprintf(&b, "\nIP: %s\n\n", node.IP) + + labels := []string{ + "Role:", + "Hostname:", + "Install disk:", + "Interface:", + "Address (CIDR):", + "Gateway:", + "DNS (comma-sep):", + "Management IP:", + } + for i, label := range labels { + style := blurredStyle + if i == m.nodeInputFocus { + style = focusedStyle + } + b.WriteString(style.Render(label) + " " + m.nodeInputs[i].View() + "\n") + } + + if m.err != nil { + b.WriteString("\n" + errorStyle.Render(m.err.Error())) + } + + b.WriteString(helpStyle.Render("\ntab next field | enter confirm | esc back")) + return b.String() +} + +func (m Model) viewConfirm() string { + var b strings.Builder + b.WriteString(titleStyle.Render("Confirm configuration")) + b.WriteString("\n\n") + fmt.Fprintf(&b, "Preset: %s\n", m.result.Preset) + fmt.Fprintf(&b, "Cluster: %s\n", m.result.ClusterName) + fmt.Fprintf(&b, "Endpoint: %s\n", m.result.Endpoint) + fmt.Fprintf(&b, "Nodes: %d\n", len(m.result.Nodes)) + + for i, node := range m.result.Nodes { + fmt.Fprintf(&b, "\n %d. %s [%s]\n", i+1, node.Hostname, node.Role) + fmt.Fprintf(&b, " address: %s\n", node.Addresses) + if node.Gateway != "" { + fmt.Fprintf(&b, " gateway: %s\n", node.Gateway) + } + fmt.Fprintf(&b, " disk: %s\n", node.DiskPath) + if node.Interface != "" { + fmt.Fprintf(&b, " iface: %s\n", node.Interface) + } + if len(node.DNS) > 0 { + fmt.Fprintf(&b, " DNS: %s\n", strings.Join(node.DNS, ", ")) + } + if node.ManagementIP != "" { + fmt.Fprintf(&b, " mgmt IP: %s\n", node.ManagementIP) + } + } + + b.WriteString(helpStyle.Render("\ny/enter generate | n restart | esc back")) + return b.String() +} + +func (m Model) viewGenerating() string { + return titleStyle.Render("Generating configuration...") + "\n\n" + + m.spinner.View() + " Creating secrets and config files...\n" +} + +func (m Model) viewDone() string { + return successStyle.Render("Configuration generated successfully!") + "\n\n" + + "Files created in the current directory.\n" + + "Next steps:\n" + + " 1. talm template --file nodes/.yaml (render machine configs)\n" + + " 2. talm apply --file nodes/.yaml (apply to nodes)\n" + + helpStyle.Render("\nPress enter or q to exit") +} + +func (m Model) viewError() string { + var b strings.Builder + b.WriteString(errorStyle.Render("Error")) + b.WriteString("\n\n") + if m.err != nil { + b.WriteString(m.err.Error()) + } + b.WriteString(helpStyle.Render("\nr retry | enter/q quit")) + return b.String() +} diff --git a/pkg/wizard/types.go b/pkg/wizard/types.go new file mode 100644 index 0000000..83cded3 --- /dev/null +++ b/pkg/wizard/types.go @@ -0,0 +1,62 @@ +package wizard + +// NodeInfo holds hardware and network information about a discovered Talos node. +type NodeInfo struct { + IP string + Hostname string + MAC string + CPU string // human-readable, e.g. "Intel Xeon E-2236 (12 threads)" + RAMBytes uint64 + Disks []Disk + Interfaces []NetInterface + DefaultGateway string // default route next-hop discovered via COSI, if any +} + +// Disk represents a block device on a node. +type Disk struct { + DevPath string // e.g. "/dev/sda" + Model string + SizeBytes uint64 +} + +// NetInterface represents a network interface on a node. +type NetInterface struct { + Name string + MAC string + IPs []string +} + +// NodeConfig holds user-specified configuration for a single node. +type NodeConfig struct { + Hostname string + Role string // "controlplane" or "worker" + DiskPath string // install disk, e.g. "/dev/sda" + Interface string // primary network interface + Addresses string // CIDR notation, e.g. "192.168.1.10/24" + Gateway string + DNS []string + VIP string // optional, controlplane only + // ManagementIP — IP reachable from the host running talm (may differ from + // the node's own address on DNAT setups). Empty → fall back to the IP + // extracted from Addresses. + ManagementIP string +} + +// WizardResult holds all collected data from the wizard flow, +// ready to be passed to GenerateProject and values.yaml generation. +type WizardResult struct { + Preset string + ClusterName string + Endpoint string // API server endpoint, e.g. "https://192.168.0.1:6443" + Nodes []NodeConfig + + // Network configuration + PodSubnets string // e.g. "10.244.0.0/16" + ServiceSubnets string // e.g. "10.96.0.0/16" + AdvertisedSubnets string // e.g. "192.168.100.0/24" + + // Cozystack-specific fields + ClusterDomain string + FloatingIP string + Image string +} diff --git a/pkg/wizard/validator.go b/pkg/wizard/validator.go new file mode 100644 index 0000000..b44b89c --- /dev/null +++ b/pkg/wizard/validator.go @@ -0,0 +1,104 @@ +package wizard + +import ( + "fmt" + "net" + "net/url" + "regexp" + "strings" +) + +var clusterNameRegexp = regexp.MustCompile(`^[a-z0-9]([a-z0-9-]*[a-z0-9])?$`) + +// ValidateClusterName checks that name is a valid DNS label: +// lowercase alphanumeric and hyphens, max 63 chars, no leading/trailing hyphens. +func ValidateClusterName(name string) error { + if name == "" { + return fmt.Errorf("cluster name must not be empty") + } + if len(name) > 63 { + return fmt.Errorf("cluster name must be at most 63 characters, got %d", len(name)) + } + if !clusterNameRegexp.MatchString(name) { + return fmt.Errorf("cluster name must contain only lowercase letters, numbers, and hyphens, and must not start or end with a hyphen") + } + return nil +} + +var hostnameRegexp = regexp.MustCompile(`^[a-zA-Z0-9]([a-zA-Z0-9-]*[a-zA-Z0-9])?$`) + +// ValidateHostname checks that hostname is a valid single-label hostname (no dots). +// FQDNs are not accepted — Talos nodes use single-label hostnames. +func ValidateHostname(hostname string) error { + if hostname == "" { + return fmt.Errorf("hostname must not be empty") + } + if len(hostname) > 63 { + return fmt.Errorf("hostname label must be at most 63 characters, got %d", len(hostname)) + } + if !hostnameRegexp.MatchString(hostname) { + return fmt.Errorf("hostname must contain only letters, numbers, and hyphens, and must not start or end with a hyphen") + } + return nil +} + +// ValidateCIDR checks that cidr is a valid CIDR notation (e.g. "192.168.1.0/24"). +func ValidateCIDR(cidr string) error { + if cidr == "" { + return fmt.Errorf("CIDR must not be empty") + } + _, _, err := net.ParseCIDR(cidr) + if err != nil { + return fmt.Errorf("invalid CIDR notation: %w", err) + } + return nil +} + +// ValidateEndpoint checks that endpoint is a valid https URL with a host and port. +// Example: "https://192.168.0.1:6443" +func ValidateEndpoint(endpoint string) error { + if endpoint == "" { + return fmt.Errorf("endpoint must not be empty") + } + if !strings.HasPrefix(endpoint, "https://") { + return fmt.Errorf("endpoint must start with https:// (e.g. https://192.168.0.1:6443)") + } + u, err := url.Parse(endpoint) + if err != nil || u.Host == "" { + return fmt.Errorf("invalid endpoint URL: %s", endpoint) + } + // url.Parse("https://:6443") yields Host=":6443" but Hostname()="". + // Reject explicitly so endpoints always carry a usable host/IP. + if u.Hostname() == "" { + return fmt.Errorf("endpoint must include a valid hostname or IP: %s", endpoint) + } + if u.Scheme != "https" { + return fmt.Errorf("endpoint must use https scheme, got %q", u.Scheme) + } + if u.Port() == "" { + return fmt.Errorf("endpoint must include a port number") + } + return nil +} + +// ValidateIP checks that ip is a valid IP address (v4 or v6). +func ValidateIP(ip string) error { + if ip == "" { + return fmt.Errorf("IP address must not be empty") + } + parsed := net.ParseIP(ip) + if parsed == nil { + return fmt.Errorf("invalid IP address: %s", ip) + } + return nil +} + +// ValidateNodeRole checks that role is either "controlplane" or "worker". +func ValidateNodeRole(role string) error { + switch role { + case "controlplane", "worker": + return nil + default: + return fmt.Errorf("node role must be %q or %q, got %q", "controlplane", "worker", role) + } +} diff --git a/pkg/wizard/validator_test.go b/pkg/wizard/validator_test.go new file mode 100644 index 0000000..8d6d987 --- /dev/null +++ b/pkg/wizard/validator_test.go @@ -0,0 +1,174 @@ +package wizard + +import ( + "strings" + "testing" +) + +func TestValidateClusterName(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + {"valid simple", "my-cluster", false}, + {"valid with numbers", "cluster-01", false}, + {"valid single word", "test", false}, + {"empty", "", true}, + {"uppercase", "MyCluster", true}, + {"starts with dash", "-cluster", true}, + {"ends with dash", "cluster-", true}, + {"contains underscore", "my_cluster", true}, + {"contains space", "my cluster", true}, + {"contains dot", "my.cluster", true}, + {"too long", strings.Repeat("a", 64), true}, + {"max valid length", strings.Repeat("a", 63), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateClusterName(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateClusterName(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + }) + } +} + +func TestValidateHostname(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + {"valid", "node-01", false}, + {"valid short", "n", false}, + {"empty", "", true}, + {"uppercase allowed", "Node01", false}, + {"starts with dash", "-node", true}, + {"ends with dash", "node-", true}, + {"contains space", "my node", true}, + {"too long", strings.Repeat("a", 64), true}, + {"max valid length", strings.Repeat("a", 63), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateHostname(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateHostname(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + }) + } +} + +func TestValidateCIDR(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + {"valid /24", "192.168.1.0/24", false}, + {"valid /16", "10.0.0.0/16", false}, + {"valid /32", "10.0.0.1/32", false}, + {"empty", "", true}, + {"no mask", "192.168.1.0", true}, + {"invalid ip", "999.999.999.999/24", true}, + {"invalid mask", "192.168.1.0/33", true}, + {"garbage", "not-a-cidr", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateCIDR(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateCIDR(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + }) + } +} + +func TestValidateEndpoint(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + {"valid https with port", "https://192.168.0.1:6443", false}, + {"valid https with hostname", "https://api.example.com:6443", false}, + {"empty", "", true}, + {"no scheme", "192.168.0.1:6443", true}, + {"http scheme", "http://192.168.0.1:6443", true}, + {"no port", "https://192.168.0.1", true}, + {"garbage", "not-a-url", true}, + {"hostname-only port", "https://:6443", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateEndpoint(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateEndpoint(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + }) + } +} + +func TestValidateEndpoint_ErrorMentionsHTTPS(t *testing.T) { + err := ValidateEndpoint("192.168.0.1:6443") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "https://") { + t.Errorf("error should mention https://, got: %v", err) + } +} + +func TestValidateIP(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + {"valid ipv4", "192.168.1.1", false}, + {"valid ipv6", "::1", false}, + {"valid ipv6 full", "2001:db8::1", false}, + {"empty", "", true}, + {"invalid", "not-an-ip", true}, + {"cidr notation", "192.168.1.0/24", true}, + {"out of range", "256.1.1.1", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateIP(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateIP(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + }) + } +} + +func TestValidateNodeRole(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + {"controlplane", "controlplane", false}, + {"worker", "worker", false}, + {"empty", "", true}, + {"master", "master", true}, + {"uppercase", "Controlplane", true}, + {"unknown", "other", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateNodeRole(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateNodeRole(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + }) + } +}