Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions pkg/espflasher/flasher.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/hex"
"fmt"
"io"
"runtime"
"time"

"go.bug.st/serial"
Expand Down Expand Up @@ -267,10 +268,26 @@ func (f *Flasher) connect() error {
// Reset the chip into bootloader mode
switch f.opts.ResetMode {
case ResetDefault:
if attempt%2 == 0 {
classicReset(f.port, defaultResetDelay)
// On Unix systems (darwin, linux), use the esptool sequence:
// unixTightReset with 50ms, then 550ms, then fallback to classic resets
if runtime.GOOS == "darwin" || runtime.GOOS == "linux" {
switch attempt % 4 {
case 0:
unixTightReset(f.port, defaultResetDelay)
case 1:
unixTightReset(f.port, extraResetDelay)
case 2:
classicReset(f.port, defaultResetDelay)
case 3:
classicReset(f.port, extraResetDelay)
}
} else {
tightReset(f.port, defaultResetDelay)
// On other systems, use classic and tight reset alternation
if attempt%2 == 0 {
classicReset(f.port, defaultResetDelay)
} else {
tightReset(f.port, defaultResetDelay)
}
}
case ResetUSBJTAG:
if f.connectViaUSBJTAG() {
Expand Down
57 changes: 52 additions & 5 deletions pkg/espflasher/reset.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,33 @@ const (

const (
// defaultResetDelay is the standard delay during reset sequences.
defaultResetDelay = 100 * time.Millisecond
// Matches esptool.py's DEFAULT_RESET_DELAY.
defaultResetDelay = 50 * time.Millisecond

// tightResetDelay is a shorter delay for Unix systems.
tightResetDelay = 50 * time.Millisecond
// extraResetDelay is used for longer-duration reset cycles on some devices.
extraResetDelay = 550 * time.Millisecond
)

// setDTRandRTS sets DTR and RTS simultaneously.
// On Unix systems (darwin, linux), uses atomic TIOCMSET ioctl to set both lines
// in a single operation. This is important for CH340 boards that require precise
// timing of the DTR/RTS transition.
// On other systems, falls back to separate SetDTR and SetRTS calls.
func setDTRandRTS(port serial.Port, dtr, rts bool) error {
// On Unix systems, try to use atomic TIOCMSET for precise timing
err := setDTRandRTSAtomic(port, dtr, rts)
if err == nil {
// Atomic operation succeeded
return nil
}

// Fallback: use separate calls (always used on non-Unix, used on Unix if atomic fails)
if err := port.SetDTR(dtr); err != nil {
return err
}
return port.SetRTS(rts)
}

// classicReset performs the classic DTR/RTS bootloader entry sequence.
//
// This is the standard reset sequence used by most USB-UART bridges:
Expand All @@ -59,12 +80,38 @@ func classicReset(port serial.Port, delay time.Duration) {
// IO0=LOW (request bootloader), EN=HIGH (release reset)
port.SetDTR(true) //nolint:errcheck
port.SetRTS(false) //nolint:errcheck
time.Sleep(tightResetDelay)
time.Sleep(defaultResetDelay)

// IO0=HIGH (release GPIO0)
port.SetDTR(false) //nolint:errcheck
}

// unixTightReset performs the esptool.py UnixTightReset sequence.
// This uses atomic DTR/RTS transitions for precise timing on Unix systems.
// It matches esptool.py's reset sequence to better support CH340 boards on macOS/Linux.
//
// Sequence (using atomic setDTRandRTS where possible):
// 1. setDTRandRTS(false, false) - IO0=HIGH, EN=HIGH (idle)
// 2. setDTRandRTS(true, true) - IO0=LOW, EN=LOW
// 3. setDTRandRTS(false, true) - IO0=HIGH, EN=LOW (chip held in reset)
// 4. Sleep 100ms
// 5. setDTRandRTS(true, false) - IO0=LOW, EN=HIGH (bootloader mode)
// 6. Sleep delay ms
// 7. setDTRandRTS(false, false) - IO0=HIGH, EN=HIGH
// 8. SetDTR(false) - ensure IO0 is released
func unixTightReset(port serial.Port, delay time.Duration) {
setDTRandRTS(port, false, false) //nolint:errcheck
setDTRandRTS(port, true, true) //nolint:errcheck
setDTRandRTS(port, false, true) //nolint:errcheck
time.Sleep(100 * time.Millisecond)

setDTRandRTS(port, true, false) //nolint:errcheck
time.Sleep(delay)

setDTRandRTS(port, false, false) //nolint:errcheck
port.SetDTR(false) //nolint:errcheck
}

// tightReset performs a tighter reset timing variant.
// Some Linux serial drivers need DTR and RTS set simultaneously.
func tightReset(port serial.Port, delay time.Duration) {
Expand All @@ -80,7 +127,7 @@ func tightReset(port serial.Port, delay time.Duration) {
// Release: IO0=LOW (bootloader), EN=HIGH (run)
port.SetDTR(false) //nolint:errcheck
port.SetRTS(false) //nolint:errcheck
time.Sleep(tightResetDelay)
time.Sleep(defaultResetDelay)

port.SetDTR(false) //nolint:errcheck
}
Expand Down
13 changes: 13 additions & 0 deletions pkg/espflasher/reset_nonunix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
//go:build !darwin && !linux

package espflasher

import "errors"

var errNoHandle = errors.New("atomic TIOCMSET not available on this platform")

// setDTRandRTSAtomic is not available on non-Unix platforms.
// Non-Unix systems will always fall back to separate SetDTR/SetRTS calls.
func setDTRandRTSAtomic(port interface{}, dtr, rts bool) error {
return errNoHandle
}
163 changes: 163 additions & 0 deletions pkg/espflasher/reset_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
package espflasher

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.bug.st/serial"
)

// recordingPort tracks all calls to SetDTR and SetRTS for testing.
// Each call is recorded as a separate event to allow testing the order and
// combinations of line state transitions.
type recordingPort struct {
dtrCalls []bool
rtsCalls []bool
}

func (r *recordingPort) SetDTR(dtr bool) error {
r.dtrCalls = append(r.dtrCalls, dtr)
return nil
}

func (r *recordingPort) SetRTS(rts bool) error {
r.rtsCalls = append(r.rtsCalls, rts)
return nil
}

func (r *recordingPort) Write(p []byte) (int, error) { return len(p), nil }
func (r *recordingPort) Read(p []byte) (int, error) { return 0, nil }
func (r *recordingPort) SetMode(mode *serial.Mode) error { return nil }
func (r *recordingPort) SetReadTimeout(t time.Duration) error { return nil }
func (r *recordingPort) SetWriteTimeout(t time.Duration) error { return nil }
func (r *recordingPort) Close() error { return nil }
func (r *recordingPort) ResetInputBuffer() error { return nil }
func (r *recordingPort) ResetOutputBuffer() error { return nil }
func (r *recordingPort) GetModemStatusBits() (*serial.ModemStatusBits, error) { return nil, nil }
func (r *recordingPort) Break(t time.Duration) error { return nil }
func (r *recordingPort) Drain() error { return nil }

// TestClassicReset verifies the classic reset sequence.
func TestClassicReset(t *testing.T) {
port := &recordingPort{}
classicReset(port, defaultResetDelay)

// Classic reset sequence:
// 1. SetDTR(false), SetRTS(true) - IO0=HIGH, EN=LOW (hold in reset)
// 2. SetDTR(true), SetRTS(false) - IO0=LOW, EN=HIGH (bootloader)
// 3. SetDTR(false) - IO0=HIGH

// Verify we have DTR calls (SetDTR is called 3 times)
assert.GreaterOrEqual(t, len(port.dtrCalls), 2, "should call SetDTR multiple times")

// Verify we have RTS calls (SetRTS is called 2 times)
assert.GreaterOrEqual(t, len(port.rtsCalls), 2, "should call SetRTS multiple times")

// Verify first DTR is false (IO0=HIGH)
assert.Equal(t, false, port.dtrCalls[0], "first SetDTR should be false")

// Verify first RTS is true (EN=LOW, chip held in reset)
assert.Equal(t, true, port.rtsCalls[0], "first SetRTS should be true")

// Verify second DTR is true (IO0=LOW for bootloader mode)
assert.Equal(t, true, port.dtrCalls[1], "second SetDTR should be true")

// Verify second RTS is false (EN=HIGH, release reset)
assert.Equal(t, false, port.rtsCalls[1], "second SetRTS should be false")
}

// TestUnixTightReset verifies the Unix tight reset sequence.
func TestUnixTightReset(t *testing.T) {
port := &recordingPort{}
unixTightReset(port, defaultResetDelay)

// UnixTightReset sequence using setDTRandRTS:
// 1. setDTRandRTS(false, false) - IO0=HIGH, EN=HIGH
// 2. setDTRandRTS(true, true) - IO0=LOW, EN=LOW
// 3. setDTRandRTS(false, true) - IO0=HIGH, EN=LOW
// 4. setDTRandRTS(true, false) - IO0=LOW, EN=HIGH (bootloader mode)
// 5. setDTRandRTS(false, false) - IO0=HIGH, EN=HIGH
// 6. SetDTR(false)

// Should have multiple DTR and RTS calls
assert.GreaterOrEqual(t, len(port.dtrCalls), 4, "should have multiple DTR calls")
assert.GreaterOrEqual(t, len(port.rtsCalls), 4, "should have multiple RTS calls")

// Verify bootloader mode is reached (DTR=true, RTS=false at indices matching)
// Look for the pattern where DTR becomes true before RTS becomes false
dtrTrueIdx := -1
rtsFalseIdx := -1
for i, val := range port.dtrCalls {
if val && dtrTrueIdx == -1 {
dtrTrueIdx = i
}
}
for i, val := range port.rtsCalls {
if !val && rtsFalseIdx == -1 {
rtsFalseIdx = i
}
}

assert.NotEqual(t, -1, dtrTrueIdx, "should set DTR=true")
assert.NotEqual(t, -1, rtsFalseIdx, "should set RTS=false")
}

// TestTightReset verifies the tight reset sequence.
func TestTightReset(t *testing.T) {
port := &recordingPort{}
tightReset(port, defaultResetDelay)

// TightReset sequence:
// 1. SetDTR(false), SetRTS(false) - IO0=HIGH, EN=HIGH
// 2. SetDTR(true), SetRTS(true) - IO0=LOW, EN=LOW
// 3. SetDTR(false), SetRTS(false) - IO0=HIGH, EN=HIGH (release)

assert.GreaterOrEqual(t, len(port.dtrCalls), 2, "should have at least 2 DTR calls")
assert.GreaterOrEqual(t, len(port.rtsCalls), 2, "should have at least 2 RTS calls")

// Verify initial state
assert.Equal(t, false, port.dtrCalls[0], "first SetDTR should be false")
assert.Equal(t, false, port.rtsCalls[0], "first SetRTS should be false")
}

// TestSetDTRandRTS verifies the setDTRandRTS helper.
func TestSetDTRandRTS(t *testing.T) {
port := &recordingPort{}
err := setDTRandRTS(port, true, false)
require.NoError(t, err)

// Should call SetDTR(true) and SetRTS(false)
assert.True(t, len(port.dtrCalls) > 0, "should call SetDTR")
assert.True(t, len(port.rtsCalls) > 0, "should call SetRTS")
assert.Equal(t, true, port.dtrCalls[len(port.dtrCalls)-1], "last DTR call should be true")
assert.Equal(t, false, port.rtsCalls[len(port.rtsCalls)-1], "last RTS call should be false")
}

// TestSetDTRandRTSBothTrue verifies setting both high.
func TestSetDTRandRTSBothTrue(t *testing.T) {
port := &recordingPort{}
err := setDTRandRTS(port, true, true)
require.NoError(t, err)

assert.Equal(t, true, port.dtrCalls[len(port.dtrCalls)-1], "last DTR call should be true")
assert.Equal(t, true, port.rtsCalls[len(port.rtsCalls)-1], "last RTS call should be true")
}

// TestSetDTRandRTSBothFalse verifies setting both low.
func TestSetDTRandRTSBothFalse(t *testing.T) {
port := &recordingPort{}
err := setDTRandRTS(port, false, false)
require.NoError(t, err)

assert.Equal(t, false, port.dtrCalls[len(port.dtrCalls)-1], "last DTR call should be false")
assert.Equal(t, false, port.rtsCalls[len(port.rtsCalls)-1], "last RTS call should be false")
}

// TestResetDelayConstants verifies the reset delay constants.
func TestResetDelayConstants(t *testing.T) {
// Verify constants match esptool expectations
assert.Equal(t, 50*time.Millisecond, defaultResetDelay, "defaultResetDelay should be 50ms")
assert.Equal(t, 550*time.Millisecond, extraResetDelay, "extraResetDelay should be 550ms")
}
47 changes: 47 additions & 0 deletions pkg/espflasher/reset_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
//go:build darwin || linux

package espflasher

import (
"errors"
"reflect"

"golang.org/x/sys/unix"
)

var errNoHandle = errors.New("port does not have accessible file descriptor")

// setDTRandRTSAtomic performs an atomic TIOCMSET ioctl on Unix systems.
// It attempts to use the underlying file descriptor to set DTR and RTS
// simultaneously, which is important for precise timing on CH340 boards.
// Returns errNoHandle if the port doesn't have an accessible fd (signals to try separate calls).
func setDTRandRTSAtomic(port interface{}, dtr, rts bool) error {
// Try to get the underlying file descriptor via type assertion.
// go.bug.st/serial's unixPort has a handle field (int), but it's unexported.
// We'll use reflect to access it.
v := reflect.ValueOf(port)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}

// Look for a field named "handle" (fd)
handleField := v.FieldByName("handle")
if !handleField.IsValid() || handleField.Kind() != reflect.Int {
// Field doesn't exist or wrong type; caller will fall back to separate calls
return errNoHandle
}

fd := int(handleField.Int())

// Build the TIOCMSET bitmask
var status int
if dtr {
status |= unix.TIOCM_DTR
}
if rts {
status |= unix.TIOCM_RTS
}

// Perform the atomic ioctl
return unix.IoctlSetPointerInt(fd, unix.TIOCMSET, status)
}
Loading