Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/mdformat/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import mdformat
from mdformat._conf import DEFAULT_OPTS, InvalidConfError, read_toml_opts
from mdformat._output import diff
from mdformat._util import detect_newline_type, is_md_equal
import mdformat.plugins

Expand Down Expand Up @@ -142,6 +143,11 @@ def run(cli_args: Sequence[str], cache_toml: bool = True) -> int: # noqa: C901
newline = detect_newline_type(original_str, opts["end_of_line"])
formatted_str = formatted_str.replace("\n", newline)

if formatted_str != original_str and opts["diff"]:
src_name = f"a/{path_str}"
dst_name = f"b/{path_str}"
print(diff(original_str, formatted_str, src_name, dst_name), end="")

if opts["check"]:
if formatted_str != original_str:
format_errors_found = True
Expand Down Expand Up @@ -176,6 +182,8 @@ def run(cli_args: Sequence[str], cache_toml: bool = True) -> int: # noqa: C901
],
)
return 1
if opts["diff"]:
continue
if path:
if formatted_str != original_str:
path.write_bytes(formatted_str.encode())
Expand Down Expand Up @@ -210,6 +218,11 @@ def make_arg_parser(
parser.add_argument(
"--check", action="store_true", help="do not apply changes to files"
)
parser.add_argument(
"--diff",
action="store_true",
help="show a diff of what would be changed",
)
parser.add_argument(
"--no-validate",
action="store_const",
Expand Down
22 changes: 22 additions & 0 deletions src/mdformat/_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import difflib


def diff(a: str, b: str, a_name: str, b_name: str) -> str:
"""Return a unified diff string between strings `a` and `b`.

Highly inspired by Black's diff function.
"""
a_lines = a.splitlines(keepends=True)
b_lines = b.splitlines(keepends=True)

diff_lines = []
for line in difflib.unified_diff(
a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5
):
if line[-1] == "\n":
diff_lines.append(line)
else:
diff_lines.append(line + "\n")
diff_lines.append("\\ No newline at end of file\n")

return "".join(diff_lines)
27 changes: 27 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,33 @@ def test_check__fail(tmp_path):
assert run((str(file_path), "--check")) == 1


def test_check_fail_diff(capsys, tmp_path):
"""Test for --check flag and --diff flag combined on unformatted files.

Test that when an unformatted file fails, a diff is writtin to
stdout.
"""

file_path = tmp_path / "test_markdown.md"
file_path.write_text(UNFORMATTED_MARKDOWN)
assert run((str(file_path), "--check", "--diff")) == 1
captured = capsys.readouterr()
assert str(file_path) in captured.out
assert "-\n-\n # A header\n-\n" in captured.out


def test_diff_without_check(capsys, tmp_path):
file_path = tmp_path / "test_markdown.md"
file_path.write_text(UNFORMATTED_MARKDOWN)

assert run((str(file_path), "--diff")) == 0

captured = capsys.readouterr()
assert str(file_path) in captured.out
assert "-\n-\n # A header\n-\n" in captured.out
assert file_path.read_text() == UNFORMATTED_MARKDOWN


def test_check__multi_fail(capsys, tmp_path):
"""Test for --check flag when multiple files are unformatted.

Expand Down