import os
import torch
import torch.nn.functional as F
from secml.settings import SECML_PYTORCH_USE_CUDA
from models.malconv2_utils.MalConvGCT_nocat import MalConvGCT
from secml.array import CArray
import settings
import numpy as np
from secml.ml.classifiers import CClassifier
# Run on the GPU only when PyTorch reports a CUDA device AND secml's
# SECML_PYTORCH_USE_CUDA setting is truthy.
use_cuda = torch.cuda.is_available() and SECML_PYTORCH_USE_CUDA
class MalConv2(CClassifier):
    """
    secml classifier shell around a pretrained MalConvGCT network.

    The network is stored in ``self.mlgct`` and is put in evaluation mode
    right after the weights are loaded. This class only holds the model:
    ``_forward`` and ``_fit`` are intentionally not implemented.
    """

    def __init__(self, pretrained_path=None):
        """
        Build the MalConvGCT network and load its pretrained weights.

        Parameters
        ----------
        pretrained_path : str or None
            Path of a checkpoint containing a ``'model_state_dict'`` entry.

        Raises
        ------
        FileNotFoundError
            If ``pretrained_path`` is None or does not exist on disk.
        """
        weights_missing = pretrained_path is None or not os.path.exists(pretrained_path)
        if weights_missing:
            raise FileNotFoundError(f"Pretrained model not found: {pretrained_path}")

        network = MalConvGCT(channels=256, window_size=256, stride=64)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        saved = torch.load(pretrained_path, map_location=settings.DEVICE)
        # strict=False: missing/unexpected keys in the checkpoint do not raise.
        network.load_state_dict(saved['model_state_dict'], strict=False)
        network.eval()
        self.mlgct = network

        if use_cuda:
            self.mlgct.cuda()

        super().__init__()

    def _forward(self, x):
        """Not available: this wrapper does not implement the secml forward pass."""
        raise NotImplementedError("Forward is not supported in this wrapper.")

    def _fit(self, x, y):
        """Not available: the wrapped network is used as pretrained, never trained here."""
        raise NotImplementedError("Training is not supported in this wrapper.")
class MalConv2Wrapper:
    """
    Adapter exposing ``decision_function`` / ``predict`` for a MalConv2 model.

    ``model`` must expose the underlying torch network as ``model.mlgct``;
    calling that network must return a 3-tuple whose first element holds the
    class logits.
    """

    def __init__(self, model):
        self.classifier = model

    @staticmethod
    def _encode(x, max_len, padding_char):
        """
        Convert one sample into a ``(1, T)`` int16 token tensor.

        Byte value ``b`` becomes token ``b + 1``; positions equal to
        ``padding_char`` become token 0.
        """
        if isinstance(x, CArray):
            # Read the stored VALUES (not the raw memory): a CArray with an
            # int64/float dtype would otherwise be seen as 8 bytes per entry.
            values = x.tondarray().ravel()[:max_len].astype(np.int16)
        else:
            # Raw bytes-like input: one token per byte.
            values = np.frombuffer(x[:max_len], dtype=np.uint8).astype(np.int16)

        # The +1 shift keeps 0 free; presumably 0 is the model's padding
        # index -- TODO confirm against MalConvGCT's embedding definition.
        tokens = np.where(values == padding_char, 0, values + 1)
        tokens = tokens.astype(np.int16, copy=False)

        # Keep int16: casting to uint8 would wrap token 256 (byte 0xFF) to 0.
        return torch.from_numpy(tokens).unsqueeze(0)

    def decision_function(self, x, y=None, max_len=settings.MAX_FILE_LEN_MALCONV2, padding_char=256):
        """
        Return the softmax class scores of the network for one sample.

        Parameters
        ----------
        x : CArray or bytes-like
            Sample to score. A CArray is read as a flat sequence of byte
            values; anything else is read through the buffer protocol.
        y : int or None
            If given, only the scores of class ``y`` are returned.
        max_len : int
            The sample is truncated to this many bytes.
        padding_char : int
            Input value that marks padding; it is mapped to token 0.

        Returns
        -------
        CArray
            All class scores when ``y`` is None, otherwise the flattened
            column for class ``y``.
        """
        network = self.classifier.mlgct

        # Feed the input on the device the network actually lives on, so the
        # two cannot disagree; fall back to settings.DEVICE if the network
        # exposes no parameters.
        first_param = next(network.parameters(), None)
        device = settings.DEVICE if first_param is None else first_param.device

        batch = self._encode(x, max_len, padding_char).to(device)

        with torch.no_grad():
            outputs, _, _ = network(batch)

        probabilities = F.softmax(outputs, dim=-1).detach().cpu().numpy()
        scores = CArray(probabilities)
        return scores if y is None else scores[:, y].ravel()

    def predict(self, x: CArray, return_decision_function: bool = True):
        """
        Predict the label of ``x`` as the class with the highest score.

        Returns ``(labels, scores)`` when ``return_decision_function`` is
        truthy, otherwise only ``labels``.
        """
        scores = self.decision_function(x, y=None)
        labels = scores.argmax(axis=1).ravel()
        return (labels, scores) if return_decision_function else labels
If not, maybe it could be added to the repo.
Thank you!
Hello, I wanted to do the same as in issue #55: run MalConv2 with secml_malware. To do so, I created a wrapper and reused the files from the existing MalConv2 implementation. It is not the best solution, but it works. Do you have any suggestions for a better implementation?
If not, maybe it could be added to the repo.
Thank you!