diff --git a/src/cellselection.ts b/src/cellselection.ts index d4fcfdca..9d1edc5f 100644 --- a/src/cellselection.ts +++ b/src/cellselection.ts @@ -29,6 +29,21 @@ export interface CellSelectionJSON { head: number; } +/** + * Options for creating a CellSelection. + * + * @public + */ +export interface CellSelectionOptions { + /** + * When true, the selection will be expanded to form a complete rectangle, + * including all cells that span across the selection boundaries. + * This is useful for mouse drag selections to prevent T-shaped or L-shaped selections. + * Default is false. + */ + forceRectangular?: boolean; +} + /** * A [`Selection`](http://prosemirror.net/docs/ref/#state.Selection) * subclass that represents a cell selection spanning part of a table. @@ -47,17 +62,25 @@ export class CellSelection extends Selection { // moves when extending the selection). public $headCell: ResolvedPos; + public forceRectangular: boolean; + // A table selection is identified by its anchor and head cells. The // positions given to this constructor should point _before_ two // cells in the same table. They may be the same, to select a single // cell. - constructor($anchorCell: ResolvedPos, $headCell: ResolvedPos = $anchorCell) { + constructor( + $anchorCell: ResolvedPos, + $headCell: ResolvedPos = $anchorCell, + options: CellSelectionOptions = {}, + ) { + const { forceRectangular = false } = options; const table = $anchorCell.node(-1); const map = TableMap.get(table); const tableStart = $anchorCell.start(-1); const rect = map.rectBetween( $anchorCell.pos - tableStart, $headCell.pos - tableStart, + forceRectangular, ); const doc = $anchorCell.node(0); @@ -81,6 +104,7 @@ export class CellSelection extends Selection { super(ranges[0].$from, ranges[0].$to, ranges); this.$anchorCell = $anchorCell; this.$headCell = $headCell; + this.forceRectangular = forceRectangular; } public map(doc: Node, mapping: Mappable): CellSelection | Selection { @@ -96,7 +120,10 @@ export class CellSelection extends Selection { return CellSelection.rowSelection($anchorCell, $headCell); else if (tableChanged && this.isColSelection()) return CellSelection.colSelection($anchorCell, $headCell); - else return new CellSelection($anchorCell, $headCell); + else + return new CellSelection($anchorCell, $headCell, { + forceRectangular: this.forceRectangular, + }); } return TextSelection.between($anchorCell, $headCell); } @@ -111,6 +138,7 @@ export class CellSelection extends Selection { const rect = map.rectBetween( this.$anchorCell.pos - tableStart, this.$headCell.pos - tableStart, + this.forceRectangular, ); const seen: Record = {}; const rows = []; @@ -212,6 +240,7 @@ export class CellSelection extends Selection { map.rectBetween( this.$anchorCell.pos - tableStart, this.$headCell.pos - tableStart, + this.forceRectangular, ), ); for (let i = 0; i < cells.length; i++) { @@ -345,8 +374,13 @@ export class CellSelection extends Selection { doc: Node, anchorCell: number, headCell: number = anchorCell, + options: CellSelectionOptions = {}, ): CellSelection { - return new CellSelection(doc.resolve(anchorCell), doc.resolve(headCell)); + return new CellSelection( + doc.resolve(anchorCell), + doc.resolve(headCell), + options, + ); } public override getBookmark(): CellBookmark { diff --git a/src/index.ts b/src/index.ts index ed3ceb38..6e93d220 100644 --- a/src/index.ts +++ b/src/index.ts @@ -17,7 +17,7 @@ import { import { tableEditingKey } from './util'; export { CellBookmark, CellSelection } from './cellselection'; -export type { CellSelectionJSON } from './cellselection'; +export type { CellSelectionJSON, CellSelectionOptions } from './cellselection'; export { columnResizing, columnResizingPluginKey, diff --git a/src/input.ts b/src/input.ts index 0917b0dc..3e3f28a8 100644 --- a/src/input.ts +++ b/src/input.ts @@ -107,10 +107,12 @@ function shiftArrow(axis: Axis, dir: Direction): Command { const $head = nextCell(cellSel.$headCell, axis, dir); if (!$head) return false; + // Keyboard shift+arrow selections should also force rectangular shape + // for consistent behavior with mouse drag selections return maybeSetSelection( state, dispatch, - new CellSelection(cellSel.$anchorCell, $head), + new CellSelection(cellSel.$anchorCell, $head, { forceRectangular: true }), ); }; } @@ -211,7 +213,11 @@ export function handleMouseDown( if (starting) $head = $anchor; else return; } - const selection = new CellSelection($anchor, $head); + // Mouse drag selections should force rectangular shape to prevent + // T-shaped or L-shaped selections when cells have colspan/rowspan + const selection = new CellSelection($anchor, $head, { + forceRectangular: true, + }); if (starting || !view.state.selection.eq(selection)) { const tr = view.state.tr.setSelection(selection); if (starting) tr.setMeta(tableEditingKey, $anchor.pos); diff --git a/src/tablemap.ts b/src/tablemap.ts index 99cad0e9..be4a6393 100644 --- a/src/tablemap.ts +++ b/src/tablemap.ts @@ -164,7 +164,9 @@ export class TableMap { } // Get the rectangle spanning the two given cells. - rectBetween(a: number, b: number): Rect { + // When forceRectangular is true, the rectangle will be expanded to include + // all cells that span across its boundaries, ensuring a truly rectangular selection. + rectBetween(a: number, b: number, forceRectangular: boolean = false): Rect { const { left: leftA, right: rightA, @@ -177,12 +179,139 @@ export class TableMap { top: topB, bottom: bottomB, } = this.findCell(b); - return { + let rect = { left: Math.min(leftA, leftB), top: Math.min(topA, topB), right: Math.max(rightA, rightB), bottom: Math.max(bottomA, bottomB), }; + + // Expand the rectangle to ensure it's truly rectangular and includes + // all cells that span across its boundaries (only when forceRectangular is true) + if (forceRectangular) { + let expanded = true; + while (expanded) { + expanded = false; + + // Cache to avoid redundant findCell() calls + const seen: Record = {}; + + // Check cells at the four edges of the rectangle + // We need to check all edges because expansion might reveal new cells + + // Top and bottom edges - check all columns + for (let col = rect.left; col < rect.right; col++) { + // Top edge + const topIndex = rect.top * this.width + col; + const topCellPos = this.map[topIndex]; + if (!seen[topCellPos]) { + seen[topCellPos] = true; + const cellRect = this.findCell(topCellPos); + + if (cellRect.left < rect.left) { + rect.left = cellRect.left; + expanded = true; + } + if (cellRect.right > rect.right) { + rect.right = cellRect.right; + expanded = true; + } + if (cellRect.top < rect.top) { + rect.top = cellRect.top; + expanded = true; + } + if (cellRect.bottom > rect.bottom) { + rect.bottom = cellRect.bottom; + expanded = true; + } + } + + // Bottom edge + if (rect.bottom > 0) { + const bottomIndex = (rect.bottom - 1) * this.width + col; + const bottomCellPos = this.map[bottomIndex]; + if (!seen[bottomCellPos]) { + seen[bottomCellPos] = true; + const cellRect = this.findCell(bottomCellPos); + + if (cellRect.left < rect.left) { + rect.left = cellRect.left; + expanded = true; + } + if (cellRect.right > rect.right) { + rect.right = cellRect.right; + expanded = true; + } + if (cellRect.top < rect.top) { + rect.top = cellRect.top; + expanded = true; + } + if (cellRect.bottom > rect.bottom) { + rect.bottom = cellRect.bottom; + expanded = true; + } + } + } + } + + // Left and right edges - check all rows + for (let row = rect.top; row < rect.bottom; row++) { + // Left edge + const leftIndex = row * this.width + rect.left; + const leftCellPos = this.map[leftIndex]; + if (!seen[leftCellPos]) { + seen[leftCellPos] = true; + const cellRect = this.findCell(leftCellPos); + + if (cellRect.left < rect.left) { + rect.left = cellRect.left; + expanded = true; + } + if (cellRect.right > rect.right) { + rect.right = cellRect.right; + expanded = true; + } + if (cellRect.top < rect.top) { + rect.top = cellRect.top; + expanded = true; + } + if (cellRect.bottom > rect.bottom) { + rect.bottom = cellRect.bottom; + expanded = true; + } + } + + // Right edge + if (rect.right > 0) { + const rightIndex = row * this.width + (rect.right - 1); + const rightCellPos = this.map[rightIndex]; + if (!seen[rightCellPos]) { + seen[rightCellPos] = true; + const cellRect = this.findCell(rightCellPos); + + if (cellRect.left < rect.left) { + rect.left = cellRect.left; + expanded = true; + } + if (cellRect.right > rect.right) { + rect.right = cellRect.right; + expanded = true; + } + if (cellRect.top < rect.top) { + rect.top = cellRect.top; + expanded = true; + } + if (cellRect.bottom > rect.bottom) { + rect.bottom = cellRect.bottom; + expanded = true; + } + } + } + } + } + } + + return rect; } // Return the position of all cells that have the top left corner in diff --git a/test/cellselection-rect.test.ts b/test/cellselection-rect.test.ts new file mode 100644 index 00000000..32641ae3 --- /dev/null +++ b/test/cellselection-rect.test.ts @@ -0,0 +1,75 @@ +import ist from 'ist'; +import { describe, it } from 'vitest'; + +import { TableMap } from '../src'; + +import { table, tr, td, p } from './build'; + +describe('CellSelection rectangular constraint', () => { + it('expands selection to include full rowspan cells', () => { + // | A | B (rowspan=2) | C | + // | D | B | E | + const tableNode = table( + tr( + /* 1*/ td(p('A')), + /* 6*/ td({ rowspan: 2 }, p('B')), + /*11*/ td(p('C')), + ), + tr(/*18*/ td(p('D')), /*23*/ td(p('E'))), + ); + + const map = TableMap.get(tableNode); + const rect = map.rectBetween(1, 11, true); + const cells = map.cellsInRect(rect); + + ist(rect.top, 0); + ist(rect.bottom, 2); + ist(rect.left, 0); + ist(rect.right, 3); + ist(cells.length, 5); + }); + + it('expands selection to include full colspan cells', () => { + // | A | B | C | + // | D (colspan=2) | E | + const tableNode = table( + tr(/* 1*/ td(p('A')), /* 6*/ td(p('B')), /*11*/ td(p('C'))), + tr(/*18*/ td({ colspan: 2 }, p('D')), /*23*/ td(p('E'))), + ); + + const map = TableMap.get(tableNode); + const rect = map.rectBetween(1, 23, true); + const cells = map.cellsInRect(rect); + + ist(rect.top, 0); + ist(rect.bottom, 2); + ist(rect.left, 0); + ist(rect.right, 3); + ist(cells.length, 5); + }); + + it('expands selection with complex rowspan and colspan', () => { + // | A | B (colspan=2) | + // | C (rowspan=2) | D | E | + // | C | F | G | + const tableNode = table( + tr(/* 1*/ td(p('A')), /* 6*/ td({ colspan: 2 }, p('B'))), + tr( + /*13*/ td({ rowspan: 2 }, p('C')), + /*18*/ td(p('D')), + /*23*/ td(p('E')), + ), + tr(/*30*/ td(p('F')), /*35*/ td(p('G'))), + ); + + const map = TableMap.get(tableNode); + const rect = map.rectBetween(1, 30, true); + const cells = map.cellsInRect(rect); + + ist(rect.top, 0); + ist(rect.bottom, 3); + ist(rect.left, 0); + ist(rect.right, 3); + ist(cells.length, 7); + }); +}); diff --git a/test/tablemap.test.ts b/test/tablemap.test.ts index 159b4e88..30a468e0 100644 --- a/test/tablemap.test.ts +++ b/test/tablemap.test.ts @@ -92,6 +92,18 @@ describe('TableMap', () => { ist(map.cellsInRect(map.rectBetween(6, 18)).join(', '), '6, 18'); }); + it('expands rectangle when forceRectangular is true', () => { + ist( + map.cellsInRect(map.rectBetween(1, 6, true)).join(', '), + '1, 6, 11, 18, 25', + ); + ist(map.cellsInRect(map.rectBetween(6, 11, true)).join(', '), '6, 11, 18'); + ist( + map.cellsInRect(map.rectBetween(18, 25, true)).join(', '), + '6, 11, 18, 25', + ); + }); + it('can find adjacent cells', () => { ist(map.nextCell(1, 'horiz', 1), 6); ist(map.nextCell(1, 'horiz', -1), null);