Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 974d51e

Browse files
authored
Add in ignore copyrighting ability for files (#46)
* Add in ignore copyrighting ability for files * ignore copyright for ORT quantize file
1 parent dddfe6d commit 974d51e

File tree

3 files changed

+49
-23
lines changed

3 files changed

+49
-23
lines changed

src/sparseml/onnx/optim/quantization/quantize.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,7 @@
66
# Modifications: quantize_data function modified for compatibility with NMIE
77
# --------------------------------------------------------------------------
88

9-
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
10-
#
11-
# Licensed under the Apache License, Version 2.0 (the "License");
12-
# you may not use this file except in compliance with the License.
13-
# You may obtain a copy of the License at
14-
#
15-
# http://www.apache.org/licenses/LICENSE-2.0
16-
#
17-
# Unless required by applicable law or agreed to in writing,
18-
# software distributed under the License is distributed on an "AS IS" BASIS,
19-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20-
# See the License for the specific language governing permissions and
21-
# limitations under the License.
22-
9+
# neuralmagic: no copyright
2310
# flake8: noqa
2411

2512
import os

tests/sparseml/pytorch/optim/test_modifier.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
create_optim_sgd,
3838
)
3939

40+
4041
from tests.sparseml.pytorch.helpers import ( # noqa isort:skip
4142
test_epoch,
4243
test_loss,

utils/copyright.py

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,15 @@
3434
"See the License for the specific language governing permissions and",
3535
"limitations under the License.",
3636
]
37+
NO_COPYRIGHT_LINE = "neuralmagic: no copyright"
3738
QUALITY_COMMAND = "quality"
3839
STYLE_COMMAND = "style"
3940

4041

4142
def parse_args():
43+
"""
44+
Setup and parse command line arguments for using the script
45+
"""
4246
parser = argparse.ArgumentParser(
4347
description=(
4448
"Add Neuralmagic copyright to the beginning of all "
@@ -73,12 +77,19 @@ def parse_args():
7377
return parser.parse_args()
7478

7579

76-
def quality(patterns: str):
80+
def quality(patterns: List[str]):
81+
"""
82+
Run a quality check across all files in the given glob patterns.
83+
This checks to make sure all matching files have the NM copyright present.
84+
If any do not, it will list them out and exit with an error.
85+
86+
:param patterns: The glob file patterns to run quality check on
87+
"""
7788
check_files = _get_files(patterns)
7889
error_files = []
7990

8091
for file in check_files:
81-
if not _contains_copyright(file):
92+
if not _dont_copyright(file) and not _contains_copyright(file):
8293
print(f"would add copyright to {file}")
8394
error_files.append(file)
8495

@@ -91,12 +102,20 @@ def quality(patterns: str):
91102
print(f"{len(check_files)} files have copyrights")
92103

93104

94-
def style(patterns: str):
105+
def style(patterns: List[str]):
106+
"""
107+
Run a style application across all files in the given glob patterns.
108+
This checks to make sure all matching files have the NM copyright present.
109+
If any do not, it will append the copyright to above the file after
110+
any already contained headers such as shebang lines.
111+
112+
:param patterns: The glob file patterns to run quality check on
113+
"""
95114
check_files = _get_files(patterns)
96115
copyrighted_files = []
97116

98117
for file in check_files:
99-
if not _contains_copyright(file):
118+
if not _dont_copyright(file) and not _contains_copyright(file):
100119
_add_copyright(file)
101120
print(f"copyrighted {file}")
102121
copyrighted_files.append(file)
@@ -110,7 +129,7 @@ def style(patterns: str):
110129
print(f"{len(check_files)} files unchanged")
111130

112131

113-
def _get_files(patterns: str) -> List[str]:
132+
def _get_files(patterns: List[str]) -> List[str]:
114133
files = []
115134

116135
for pattern in patterns:
@@ -122,6 +141,18 @@ def _get_files(patterns: str) -> List[str]:
122141
return files
123142

124143

144+
def _dont_copyright(file_path: str) -> bool:
145+
with open(file_path, "r") as file:
146+
content = file.read()
147+
148+
try:
149+
content.index(NO_COPYRIGHT_LINE)
150+
151+
return True
152+
except ValueError:
153+
return False
154+
155+
125156
def _contains_copyright(file_path: str) -> bool:
126157
with open(file_path, "r") as file:
127158
content = file.read()
@@ -146,17 +177,24 @@ def _add_copyright(file_path: str):
146177
with open(file_path, "r+") as file:
147178
lines = file.readlines()
148179
header_info = _file_header_info(lines, file_type)
149-
inject_index = -1
180+
inject_index = 0
150181

151182
if header_info.end_index > -1:
152-
lines.insert(header_info.end_index + 1, "\n")
183+
# if there is already a header, we want to inject the copyright after it
184+
# additionally we'll need a new line between the prev header and copyright
153185
inject_index = header_info.end_index + 1
186+
lines.insert(inject_index, "\n")
187+
inject_index += 1
154188

189+
# add the copyright at the inject index
155190
file_copyright = _file_copyright(file_type)
156-
lines.insert(inject_index + 1, file_copyright)
191+
lines.insert(inject_index, file_copyright)
157192

158193
if not header_info.new_line_after:
159-
lines.insert(inject_index + 2, "\n")
194+
# if there wasn't a new line after the header,
195+
# add in a new line after to create space between the code and copyright
196+
inject_index += 1
197+
lines.insert(inject_index, "\n")
160198

161199
file.seek(0)
162200
file.writelines(lines)

0 commit comments

Comments
 (0)