-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtorch_classification_diagnose.py
More file actions
104 lines (77 loc) · 2.95 KB
/
torch_classification_diagnose.py
File metadata and controls
104 lines (77 loc) · 2.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import torchvision.models as models
from PIL import Image, ImageDraw, ImageFont
from tqdm.auto import tqdm
import warnings
# Suppress every Python warning for the remainder of the process.
# NOTE(review): this filter is global, so it also hides warnings raised by
# unrelated code; consider narrowing it by category/module.
warnings.simplefilter("ignore")
# Use the current CUDA device when one is available, otherwise the CPU.
# To force CPU inference, assign torch.device("cpu") here instead.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class BaseModel(nn.Module):
    """EfficientNetV2-S backbone followed by a linear classification head.

    The torchvision backbone keeps its own 1000-way output layer; the extra
    ``classifier`` maps those 1000 values to ``num_classes`` logits.
    """

    def __init__(self, num_classes=2, pretrained=True):
        """Build the backbone and the classification head.

        Args:
            num_classes: Number of logits produced by ``classifier``.
            pretrained: If True (the default, same as before), initialise the
                backbone with torchvision's ImageNet weights, which torchvision
                downloads/caches on first use. Pass False when a complete
                state_dict is loaded right afterwards, to skip that download.
        """
        super().__init__()
        # `pretrained=` is deprecated in favour of `weights=`. efficientnet_v2_s
        # was added in torchvision 0.13, the same release that introduced the
        # `weights=` API, so it is always available where this model exists.
        # IMAGENET1K_V1 is what the legacy `pretrained=True` mapped to.
        weights = models.EfficientNet_V2_S_Weights.IMAGENET1K_V1 if pretrained else None
        self.backbone = models.efficientnet_v2_s(weights=weights)
        self.classifier = nn.Linear(1000, num_classes)

    def forward(self, x):
        """Return classifier logits (last dimension has size ``num_classes``)."""
        x = self.backbone(x)
        return self.classifier(x)
class CustomDataset(Dataset):
    """Images read from disk with OpenCV, optionally paired with labels.

    Args:
        img_path_list: Either a single image path (str, bytes or path-like
            object), treated as a one-image dataset, or a sequence of paths.
        label_list: Labels aligned with the paths by index, or None for
            unlabelled data (``__getitem__`` then returns only the image).
        transforms: Optional callable invoked as ``transforms(image=...)`` and
            returning a dict with an ``"image"`` key (albumentations style).
    """

    def __init__(self, img_path_list, label_list, transforms=None):
        # Accept one path as well as a list of paths. Wrapping a real list in
        # another list would produce a 1-item dataset whose item is the list.
        is_single_path = isinstance(img_path_list, (str, bytes)) or hasattr(
            img_path_list, "__fspath__"
        )
        self.img_path_list = [img_path_list] if is_single_path else list(img_path_list)
        self.label_list = label_list
        self.transforms = transforms

    def __getitem__(self, index):
        """Return the (transformed) image, plus its label when labels were given.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        img_path = self.img_path_list[index]
        # OpenCV loads pixels in BGR channel order.
        image = cv2.imread(img_path)
        # cv2.imread returns None (instead of raising) for missing/unreadable
        # files; fail here with a clear message rather than inside transforms.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        if self.transforms is not None:
            image = self.transforms(image=image)["image"]
        if self.label_list is None:
            return image
        return image, self.label_list[index]

    def __len__(self):
        """Return the number of image paths."""
        return len(self.img_path_list)
# Deterministic inference preprocessing:
#   Resize to 224x224, Normalize (albumentations' default mean/std),
#   and ToTensorV2 (HWC numpy array -> CHW torch tensor).
# Random augmentations (ChannelShuffle, Rotate, ShiftScaleRotate,
# HorizontalFlip) belong in the training pipeline only. Applied at inference
# time they make repeated predictions on the same image differ.
test_transform = A.Compose(
    [
        A.Resize(224, 224),
        A.Normalize(),
        ToTensorV2(),
    ]
)
def add_text_to_image(image, text, position=(10, 10), font_size=10, font_path="./godic.ttf"):
    """Draw ``text`` in red onto ``image`` and return it.

    The image is modified in place (ImageDraw draws directly onto it); the
    same object is returned for convenience.

    Args:
        image: PIL image to annotate.
        text: Text to render.
        position: (x, y) position passed to ``ImageDraw.text``.
        font_size: Size passed to ``ImageFont.truetype``.
        font_path: TrueType font file. The default "./godic.ttf" is resolved
            relative to the current working directory. The font must contain
            glyphs for the text drawn; the labels in this file include Hangul
            (e.g. "정상").

    Raises:
        OSError: if the font file cannot be opened.
    """
    font = ImageFont.truetype(font_path, font_size)
    draw = ImageDraw.Draw(image)
    draw.text(position, text, fill="red", font=font)
    return image
def efficientnet_inference(img_path, model_path, disease):
    """Classify one image with the two-class EfficientNet model and annotate it.

    Args:
        img_path: Path of the image to classify. It is read twice: by OpenCV
            (via CustomDataset) for the model input, and by PIL for annotation.
        model_path: Checkpoint file passed to ``torch.load`` and loaded with
            ``load_state_dict`` into ``BaseModel(num_classes=2)``.
        disease: Label reported for class index 0. Class index 1 maps to
            "정상" (Korean for "normal").

    Returns:
        tuple: (annotated PIL image, predicted label, softmax probability of
        the predicted class as a Python float).
    """
    custom_labels = {0: disease, 1: "정상"}

    # Exactly one image: take the single dataset item and add a batch axis
    # (same result as DataLoader(batch_size=1) collation) instead of looping.
    test_dataset = CustomDataset(img_path, None, test_transform)
    imgs = test_dataset[0].unsqueeze(0).to(device)

    # NOTE(review): the model is rebuilt and the checkpoint reloaded on every
    # call; cache the loaded model if this function is called repeatedly.
    model = BaseModel(num_classes=2)
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints (torch >= 1.13 supports weights_only=True).
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()

    with torch.no_grad():
        probs = F.softmax(model(imgs), dim=1)
    # One pass yields both the top probability and its class index.
    top_prob, top_idx = probs.max(dim=1)
    disease_name = custom_labels[top_idx.item()]
    confidence = top_prob.item()

    # Annotate a PIL image rather than a numpy array: the original author noted
    # a numpy array could not be saved by the caller.
    image = Image.open(img_path)
    res_plotted = add_text_to_image(image, f"{disease_name} ({confidence:.2f})")
    return res_plotted, disease_name, confidence