diff --git a/pkg/espflasher/flasher.go b/pkg/espflasher/flasher.go index 32ea28a..197fbc9 100644 --- a/pkg/espflasher/flasher.go +++ b/pkg/espflasher/flasher.go @@ -7,6 +7,7 @@ import ( "encoding/hex" "fmt" "io" + "runtime" "time" "go.bug.st/serial" @@ -267,10 +268,26 @@ func (f *Flasher) connect() error { // Reset the chip into bootloader mode switch f.opts.ResetMode { case ResetDefault: - if attempt%2 == 0 { - classicReset(f.port, defaultResetDelay) + // On Unix systems (darwin, linux), use the esptool sequence: + // unixTightReset with 50ms, then 550ms, then fallback to classic resets + if runtime.GOOS == "darwin" || runtime.GOOS == "linux" { + switch attempt % 4 { + case 0: + unixTightReset(f.port, defaultResetDelay) + case 1: + unixTightReset(f.port, extraResetDelay) + case 2: + classicReset(f.port, defaultResetDelay) + case 3: + classicReset(f.port, extraResetDelay) + } } else { - tightReset(f.port, defaultResetDelay) + // On other systems, use classic and tight reset alternation + if attempt%2 == 0 { + classicReset(f.port, defaultResetDelay) + } else { + tightReset(f.port, defaultResetDelay) + } } case ResetUSBJTAG: if f.connectViaUSBJTAG() { diff --git a/pkg/espflasher/reset.go b/pkg/espflasher/reset.go index e4a2227..c1df86b 100644 --- a/pkg/espflasher/reset.go +++ b/pkg/espflasher/reset.go @@ -29,12 +29,33 @@ const ( const ( // defaultResetDelay is the standard delay during reset sequences. - defaultResetDelay = 100 * time.Millisecond + // Matches esptool.py's DEFAULT_RESET_DELAY. + defaultResetDelay = 50 * time.Millisecond - // tightResetDelay is a shorter delay for Unix systems. - tightResetDelay = 50 * time.Millisecond + // extraResetDelay is used for longer-duration reset cycles on some devices. + extraResetDelay = 550 * time.Millisecond ) +// setDTRandRTS sets DTR and RTS simultaneously. +// On Unix systems (darwin, linux), uses atomic TIOCMSET ioctl to set both lines +// in a single operation. This is important for CH340 boards that require precise +// timing of the DTR/RTS transition. +// On other systems, falls back to separate SetDTR and SetRTS calls. +func setDTRandRTS(port serial.Port, dtr, rts bool) error { + // On Unix systems, try to use atomic TIOCMSET for precise timing + err := setDTRandRTSAtomic(port, dtr, rts) + if err == nil { + // Atomic operation succeeded + return nil + } + + // Fallback: use separate calls (always used on non-Unix, used on Unix if atomic fails) + if err := port.SetDTR(dtr); err != nil { + return err + } + return port.SetRTS(rts) +} + // classicReset performs the classic DTR/RTS bootloader entry sequence. // // This is the standard reset sequence used by most USB-UART bridges: @@ -59,12 +80,38 @@ func classicReset(port serial.Port, delay time.Duration) { // IO0=LOW (request bootloader), EN=HIGH (release reset) port.SetDTR(true) //nolint:errcheck port.SetRTS(false) //nolint:errcheck - time.Sleep(tightResetDelay) + time.Sleep(defaultResetDelay) // IO0=HIGH (release GPIO0) port.SetDTR(false) //nolint:errcheck } +// unixTightReset performs the esptool.py UnixTightReset sequence. +// This uses atomic DTR/RTS transitions for precise timing on Unix systems. +// It matches esptool.py's reset sequence to better support CH340 boards on macOS/Linux. +// +// Sequence (using atomic setDTRandRTS where possible): +// 1. setDTRandRTS(false, false) - IO0=HIGH, EN=HIGH (idle) +// 2. setDTRandRTS(true, true) - IO0=LOW, EN=LOW +// 3. setDTRandRTS(false, true) - IO0=HIGH, EN=LOW (chip held in reset) +// 4. Sleep 100ms +// 5. setDTRandRTS(true, false) - IO0=LOW, EN=HIGH (bootloader mode) +// 6. Sleep delay ms +// 7. setDTRandRTS(false, false) - IO0=HIGH, EN=HIGH +// 8. SetDTR(false) - ensure IO0 is released +func unixTightReset(port serial.Port, delay time.Duration) { + setDTRandRTS(port, false, false) //nolint:errcheck + setDTRandRTS(port, true, true) //nolint:errcheck + setDTRandRTS(port, false, true) //nolint:errcheck + time.Sleep(100 * time.Millisecond) + + setDTRandRTS(port, true, false) //nolint:errcheck + time.Sleep(delay) + + setDTRandRTS(port, false, false) //nolint:errcheck + port.SetDTR(false) //nolint:errcheck +} + // tightReset performs a tighter reset timing variant. // Some Linux serial drivers need DTR and RTS set simultaneously. func tightReset(port serial.Port, delay time.Duration) { @@ -80,7 +127,7 @@ func tightReset(port serial.Port, delay time.Duration) { // Release: IO0=LOW (bootloader), EN=HIGH (run) port.SetDTR(false) //nolint:errcheck port.SetRTS(false) //nolint:errcheck - time.Sleep(tightResetDelay) + time.Sleep(defaultResetDelay) port.SetDTR(false) //nolint:errcheck } diff --git a/pkg/espflasher/reset_nonunix.go b/pkg/espflasher/reset_nonunix.go new file mode 100644 index 0000000..b77d4e4 --- /dev/null +++ b/pkg/espflasher/reset_nonunix.go @@ -0,0 +1,13 @@ +//go:build !darwin && !linux + +package espflasher + +import "errors" + +var errNoHandle = errors.New("atomic TIOCMSET not available on this platform") + +// setDTRandRTSAtomic is not available on non-Unix platforms. +// Non-Unix systems will always fall back to separate SetDTR/SetRTS calls. +func setDTRandRTSAtomic(port interface{}, dtr, rts bool) error { + return errNoHandle +} diff --git a/pkg/espflasher/reset_test.go b/pkg/espflasher/reset_test.go new file mode 100644 index 0000000..912c144 --- /dev/null +++ b/pkg/espflasher/reset_test.go @@ -0,0 +1,163 @@ +package espflasher + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.bug.st/serial" +) + +// recordingPort tracks all calls to SetDTR and SetRTS for testing. +// Each call is recorded as a separate event to allow testing the order and +// combinations of line state transitions. +type recordingPort struct { + dtrCalls []bool + rtsCalls []bool +} + +func (r *recordingPort) SetDTR(dtr bool) error { + r.dtrCalls = append(r.dtrCalls, dtr) + return nil +} + +func (r *recordingPort) SetRTS(rts bool) error { + r.rtsCalls = append(r.rtsCalls, rts) + return nil +} + +func (r *recordingPort) Write(p []byte) (int, error) { return len(p), nil } +func (r *recordingPort) Read(p []byte) (int, error) { return 0, nil } +func (r *recordingPort) SetMode(mode *serial.Mode) error { return nil } +func (r *recordingPort) SetReadTimeout(t time.Duration) error { return nil } +func (r *recordingPort) SetWriteTimeout(t time.Duration) error { return nil } +func (r *recordingPort) Close() error { return nil } +func (r *recordingPort) ResetInputBuffer() error { return nil } +func (r *recordingPort) ResetOutputBuffer() error { return nil } +func (r *recordingPort) GetModemStatusBits() (*serial.ModemStatusBits, error) { return nil, nil } +func (r *recordingPort) Break(t time.Duration) error { return nil } +func (r *recordingPort) Drain() error { return nil } + +// TestClassicReset verifies the classic reset sequence. +func TestClassicReset(t *testing.T) { + port := &recordingPort{} + classicReset(port, defaultResetDelay) + + // Classic reset sequence: + // 1. SetDTR(false), SetRTS(true) - IO0=HIGH, EN=LOW (hold in reset) + // 2. SetDTR(true), SetRTS(false) - IO0=LOW, EN=HIGH (bootloader) + // 3. SetDTR(false) - IO0=HIGH + + // Verify we have DTR calls (SetDTR is called 3 times) + assert.GreaterOrEqual(t, len(port.dtrCalls), 2, "should call SetDTR multiple times") + + // Verify we have RTS calls (SetRTS is called 2 times) + assert.GreaterOrEqual(t, len(port.rtsCalls), 2, "should call SetRTS multiple times") + + // Verify first DTR is false (IO0=HIGH) + assert.Equal(t, false, port.dtrCalls[0], "first SetDTR should be false") + + // Verify first RTS is true (EN=LOW, chip held in reset) + assert.Equal(t, true, port.rtsCalls[0], "first SetRTS should be true") + + // Verify second DTR is true (IO0=LOW for bootloader mode) + assert.Equal(t, true, port.dtrCalls[1], "second SetDTR should be true") + + // Verify second RTS is false (EN=HIGH, release reset) + assert.Equal(t, false, port.rtsCalls[1], "second SetRTS should be false") +} + +// TestUnixTightReset verifies the Unix tight reset sequence. +func TestUnixTightReset(t *testing.T) { + port := &recordingPort{} + unixTightReset(port, defaultResetDelay) + + // UnixTightReset sequence using setDTRandRTS: + // 1. setDTRandRTS(false, false) - IO0=HIGH, EN=HIGH + // 2. setDTRandRTS(true, true) - IO0=LOW, EN=LOW + // 3. setDTRandRTS(false, true) - IO0=HIGH, EN=LOW + // 4. setDTRandRTS(true, false) - IO0=LOW, EN=HIGH (bootloader mode) + // 5. setDTRandRTS(false, false) - IO0=HIGH, EN=HIGH + // 6. SetDTR(false) + + // Should have multiple DTR and RTS calls + assert.GreaterOrEqual(t, len(port.dtrCalls), 4, "should have multiple DTR calls") + assert.GreaterOrEqual(t, len(port.rtsCalls), 4, "should have multiple RTS calls") + + // Verify bootloader mode is reached (DTR=true, RTS=false at indices matching) + // Look for the pattern where DTR becomes true before RTS becomes false + dtrTrueIdx := -1 + rtsFalseIdx := -1 + for i, val := range port.dtrCalls { + if val && dtrTrueIdx == -1 { + dtrTrueIdx = i + } + } + for i, val := range port.rtsCalls { + if !val && rtsFalseIdx == -1 { + rtsFalseIdx = i + } + } + + assert.NotEqual(t, -1, dtrTrueIdx, "should set DTR=true") + assert.NotEqual(t, -1, rtsFalseIdx, "should set RTS=false") +} + +// TestTightReset verifies the tight reset sequence. +func TestTightReset(t *testing.T) { + port := &recordingPort{} + tightReset(port, defaultResetDelay) + + // TightReset sequence: + // 1. SetDTR(false), SetRTS(false) - IO0=HIGH, EN=HIGH + // 2. SetDTR(true), SetRTS(true) - IO0=LOW, EN=LOW + // 3. SetDTR(false), SetRTS(false) - IO0=HIGH, EN=HIGH (release) + + assert.GreaterOrEqual(t, len(port.dtrCalls), 2, "should have at least 2 DTR calls") + assert.GreaterOrEqual(t, len(port.rtsCalls), 2, "should have at least 2 RTS calls") + + // Verify initial state + assert.Equal(t, false, port.dtrCalls[0], "first SetDTR should be false") + assert.Equal(t, false, port.rtsCalls[0], "first SetRTS should be false") +} + +// TestSetDTRandRTS verifies the setDTRandRTS helper. +func TestSetDTRandRTS(t *testing.T) { + port := &recordingPort{} + err := setDTRandRTS(port, true, false) + require.NoError(t, err) + + // Should call SetDTR(true) and SetRTS(false) + assert.True(t, len(port.dtrCalls) > 0, "should call SetDTR") + assert.True(t, len(port.rtsCalls) > 0, "should call SetRTS") + assert.Equal(t, true, port.dtrCalls[len(port.dtrCalls)-1], "last DTR call should be true") + assert.Equal(t, false, port.rtsCalls[len(port.rtsCalls)-1], "last RTS call should be false") +} + +// TestSetDTRandRTSBothTrue verifies setting both high. +func TestSetDTRandRTSBothTrue(t *testing.T) { + port := &recordingPort{} + err := setDTRandRTS(port, true, true) + require.NoError(t, err) + + assert.Equal(t, true, port.dtrCalls[len(port.dtrCalls)-1], "last DTR call should be true") + assert.Equal(t, true, port.rtsCalls[len(port.rtsCalls)-1], "last RTS call should be true") +} + +// TestSetDTRandRTSBothFalse verifies setting both low. +func TestSetDTRandRTSBothFalse(t *testing.T) { + port := &recordingPort{} + err := setDTRandRTS(port, false, false) + require.NoError(t, err) + + assert.Equal(t, false, port.dtrCalls[len(port.dtrCalls)-1], "last DTR call should be false") + assert.Equal(t, false, port.rtsCalls[len(port.rtsCalls)-1], "last RTS call should be false") +} + +// TestResetDelayConstants verifies the reset delay constants. +func TestResetDelayConstants(t *testing.T) { + // Verify constants match esptool expectations + assert.Equal(t, 50*time.Millisecond, defaultResetDelay, "defaultResetDelay should be 50ms") + assert.Equal(t, 550*time.Millisecond, extraResetDelay, "extraResetDelay should be 550ms") +} diff --git a/pkg/espflasher/reset_unix.go b/pkg/espflasher/reset_unix.go new file mode 100644 index 0000000..a42a661 --- /dev/null +++ b/pkg/espflasher/reset_unix.go @@ -0,0 +1,47 @@ +//go:build darwin || linux + +package espflasher + +import ( + "errors" + "reflect" + + "golang.org/x/sys/unix" +) + +var errNoHandle = errors.New("port does not have accessible file descriptor") + +// setDTRandRTSAtomic performs an atomic TIOCMSET ioctl on Unix systems. +// It attempts to use the underlying file descriptor to set DTR and RTS +// simultaneously, which is important for precise timing on CH340 boards. +// Returns errNoHandle if the port doesn't have an accessible fd (signals to try separate calls). +func setDTRandRTSAtomic(port interface{}, dtr, rts bool) error { + // Try to get the underlying file descriptor via type assertion. + // go.bug.st/serial's unixPort has a handle field (int), but it's unexported. + // We'll use reflect to access it. + v := reflect.ValueOf(port) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + // Look for a field named "handle" (fd) + handleField := v.FieldByName("handle") + if !handleField.IsValid() || handleField.Kind() != reflect.Int { + // Field doesn't exist or wrong type; caller will fall back to separate calls + return errNoHandle + } + + fd := int(handleField.Int()) + + // Build the TIOCMSET bitmask + var status int + if dtr { + status |= unix.TIOCM_DTR + } + if rts { + status |= unix.TIOCM_RTS + } + + // Perform the atomic ioctl + return unix.IoctlSetPointerInt(fd, unix.TIOCMSET, status) +}