diff --git a/diff_cover/diff_quality_tool.py b/diff_cover/diff_quality_tool.py index e49320be..9d485a40 100644 --- a/diff_cover/diff_quality_tool.py +++ b/diff_cover/diff_quality_tool.py @@ -60,6 +60,7 @@ pydocstyle_driver, pyflakes_driver, ruff_check_driver, + ruff_format_driver, shellcheck_driver, ) @@ -70,6 +71,7 @@ "pyflakes": pyflakes_driver, "pylint": PylintDriver(), "ruff.check": ruff_check_driver, + "ruff.format": ruff_format_driver, "flake8": flake8_driver, "jshint": jshint_driver, "eslint": EslintDriver(), diff --git a/diff_cover/report_generator.py b/diff_cover/report_generator.py index 64bf35c0..d6f2d399 100644 --- a/diff_cover/report_generator.py +++ b/diff_cover/report_generator.py @@ -11,6 +11,7 @@ from diff_cover.snippets import Snippet from diff_cover.util import to_unix_path +from diff_cover.violationsreporters.base import Violation class DiffViolations: @@ -19,9 +20,10 @@ class DiffViolations: """ def __init__(self, violations, measured_lines, diff_lines): - self.lines = {violation.line for violation in violations}.intersection( - diff_lines - ) + _lines = {violation.line for violation in violations} + self.lines = _lines.intersection(diff_lines) + if Violation.ALL_LINES in _lines: + self.lines.add(Violation.ALL_LINES) self.violations = { violation for violation in violations if violation.line in self.lines @@ -99,6 +101,9 @@ def percent_covered(self, src_path): if diff_violations is None: return None + if Violation.ALL_LINES in diff_violations.lines: + return 0 + # Protect against a divide by zero num_measured = len(diff_violations.measured_lines) if num_measured > 0: diff --git a/diff_cover/util.py b/diff_cover/util.py index 7efa45d4..6886d30f 100644 --- a/diff_cover/util.py +++ b/diff_cover/util.py @@ -88,3 +88,32 @@ def to_unescaped_filename(filename: str) -> str: i += 1 return "".join(result) + + +def merge_ranges(nums): + """ + Merge a list of numbers into a list of ranges. + Given a list of numbers, merge consecutive numbers + into ranges of strings e.g. [1, 2, 3] -> ["1-3"] + """ + if not nums: + return [] + nums = sorted(set(nums)) + ranges = [] + start = prev = nums[0] + + def add_range(start_, end_): + """Helper function to add a range to the list.""" + if start_ == end_: + ranges.append(str(start_)) + else: + ranges.append(f"{start_}-{end_}") + + for n in nums[1:]: + if n == prev + 1: + prev = n + continue + add_range(start, prev) + start = prev = n + add_range(start, prev) + return ranges diff --git a/diff_cover/violationsreporters/base.py b/diff_cover/violationsreporters/base.py index 803845f7..c063122b 100644 --- a/diff_cover/violationsreporters/base.py +++ b/diff_cover/violationsreporters/base.py @@ -4,11 +4,41 @@ import sys from abc import ABC, abstractmethod from collections import defaultdict, namedtuple +from functools import lru_cache from diff_cover.command_runner import execute, run_command_for_code from diff_cover.util import to_unix_path -Violation = namedtuple("Violation", "line, message") + +class Violation(namedtuple("_", "line, message")): + ALL_LINES = -1 + + +class SourceFile: + __slots__ = ("path", "violations") + + def __init__(self, path): + self.path = path + self.violations = [] + + def add_violation(self, violation): + self.violations.add(violation) + + @property + @lru_cache(maxsize=1) + def content(self): + with open(self.path, "r", encoding="utf-8") as f: + return f.read() + + @property + @lru_cache(maxsize=1) + def size(self): + return len(self.content) + + @property + @lru_cache(maxsize=1) + def lines(self): + return set(self.content.splitlines()) class QualityReporterError(Exception): @@ -228,20 +258,25 @@ def parse_reports(self, reports): violations_dict = defaultdict(list) for report in reports: if self.expression.flags & re.MULTILINE: - matches = (match for match in re.finditer(self.expression, report)) + matches = re.finditer(self.expression, report) else: matches = ( self.expression.match(line.rstrip()) for line in report.split("\n") ) for match in matches: - if match is not None: - src, line_number, message = match.groups() - # Transform src to a relative path, if it isn't already - src = to_unix_path(os.path.relpath(src)) - violation = Violation(int(line_number), message.rstrip()) - violations_dict[src].append(violation) + if match is None: + continue + src, violation = self._get_violation(match) + src = to_unix_path(os.path.relpath(src)) + violations_dict[src].append(violation) return violations_dict + def _get_violation(self, match): + src, line_number, message = match.groups() + # Transform src to a relative path, if it isn't already + src = os.path.relpath(src) + return src, Violation(int(line_number), message) + def installed(self): """ Method checks if the provided tool is installed. diff --git a/diff_cover/violationsreporters/violations_reporter.py b/diff_cover/violationsreporters/violations_reporter.py index e0a8ef9f..a70d812c 100644 --- a/diff_cover/violationsreporters/violations_reporter.py +++ b/diff_cover/violationsreporters/violations_reporter.py @@ -236,7 +236,10 @@ def _cache_file(self, src_path): # This is an unreported line. # We add it with the previous line hit score line_nodes.append( - {_hits: last_hit_number, _number: line_number} + { + _hits: last_hit_number, + _number: line_number, + } ) # First case, need to define violations initially @@ -571,6 +574,29 @@ def measured_lines(self, src_path): exit_codes=[0, 1], ) + +class RuffFormatDriver(RegexBasedDriver): + def _get_violation(self, match): + src = match.groups()[0] + # Transform src to a relative path, if it isn't already + src = os.path.relpath(src) + return src, Violation(Violation.ALL_LINES, "Needs reformat") + + +ruff_format_driver = RuffFormatDriver( + name="ruff.format", + supported_extensions=["py"], + command=["ruff", "format", "--check"], + # Match lines of the form: + # Would reformat: path/to/file.py + # Would reformat: path/to/file2.py + expression=r"^Would reformat: ([^:]+)$", + command_to_check_install=["ruff", "--version"], + # ruff exit code is 1 if there are violations + # https://docs.astral.sh/ruff/linter/#exit-codes + exit_codes=[0, 1], +) + """ Report Flake8 violations. """ diff --git a/tests/test_util.py b/tests/test_util.py index 52bfef05..66f176a7 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -78,3 +78,21 @@ def test_open_file_encoding_binary(tmp_path): with util.open_file(tmp_path / "some_file.txt", "br", encoding="utf-16") as f: assert not hasattr(f, "encoding") assert f.read() == b"cafe naive resume" + + +@pytest.mark.parametrize( + ( + "nums", + "expected", + ), + ( + ([], []), + ([1, 2, 3], ["1-3"]), + ([1, 2, 1, 2, 2, 3, 3], ["1-3"]), + ([1, 2, 3, 5, 6, 7], ["1-3", "5-7"]), + ([1, 3, 6, 8], ["1", "3", "6", "8"]), + (range(1, 101), ["1-100"]), + ), +) +def test_merge_ranges(nums, expected): + assert util.merge_ranges(nums) == expected