|
- import jittor as jt
- import jittor.nn as nn
- from jittor.dataset import Dataset
- from jittor.transform import Compose, Resize, CenterCrop, RandomCrop, RandomHorizontalFlip, ToTensor, ImageNormalize, RandomRotation, ColorJitter, RandomAffine, RandomVerticalFlip, RandomResizedCrop
- from jittor.models import Resnet50,Resnet34,Resnet18
- from tqdm import tqdm
- import os
- import numpy as np
- from PIL import Image
- import argparse
- import matplotlib.pyplot as plt
-
# Run on the GPU when this Jittor build has CUDA support; otherwise fall back
# to the CPU instead of hard-requiring a GPU.
jt.flags.use_cuda = 1 if jt.has_cuda else 0
-
-
- # ============== Dataset ==============
class ImageFolder(Dataset):
    """Image dataset over a directory, optionally labelled by an annotation file.

    With ``annotation_path``, every non-blank line of that file is read as
    ``<image path relative to root> <integer label>`` and samples are
    ``(image, label)``. Without it, every entry of ``root`` (sorted by name)
    is a sample and the entry's file name is returned in place of the label,
    so predictions can later be written out next to their image names.

    Args:
        root: directory the image paths are relative to.
        annotation_path: optional path to the ``<path> <label>`` list.
        transform: optional callable applied to the RGB ``PIL.Image``.
        **kwargs: forwarded to ``jittor.dataset.Dataset`` (e.g. ``batch_size``,
            ``num_workers``, ``shuffle``).
    """

    def __init__(self, root, annotation_path=None, transform=None, **kwargs):
        super().__init__(**kwargs)
        self.root = root
        self.transform = transform
        if annotation_path is not None:
            data_dir = self._read_annotations(annotation_path)
        else:
            data_dir = [(name, None) for name in sorted(os.listdir(root))]
        self.data_dir = data_dir
        # Presumably read by jittor's Dataset to size an epoch — TODO confirm.
        self.total_len = len(self.data_dir)

    @staticmethod
    def _read_annotations(annotation_path):
        """Parse ``<path> <label>`` lines into ``(path, int(label))`` tuples."""
        samples = []
        with open(annotation_path, 'r') as f:
            for line in f:
                # split() (not split(' ')) tolerates tabs and repeated spaces.
                parts = line.split()
                if not parts:
                    # Skip blank lines, e.g. a trailing empty line at EOF.
                    continue
                samples.append((parts[0], int(parts[1])))
        return samples

    def __getitem__(self, idx):
        image_name, label = self.data_dir[idx]
        image_path = os.path.join(self.root, image_name)
        # Close the file handle promptly; convert() returns a new, fully loaded
        # image, so it stays valid after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        if label is None:
            # Unlabelled (test) mode: return the file name as the "label".
            label = image_name
        # Assumes the transform yields an array-like (e.g. ToTensor output);
        # without a transform a PIL image is passed to jt.array — TODO confirm.
        return jt.array(image), label
-
- # ============== Model ==============
# Backbone constructors selectable through ``Net(arch=...)``.
_RESNET_BUILDERS = {
    'resnet18': Resnet18,
    'resnet34': Resnet34,
    'resnet50': Resnet50,
}


class Net(nn.Module):
    """ResNet classifier with a ``num_classes``-way output layer.

    Args:
        num_classes: number of output logits.
        pretrain: if true, copy the official pretrained weights into every
            layer except ``fc``, whose shape depends on ``num_classes``.
        arch: backbone name, a key of ``_RESNET_BUILDERS``. Defaults to
            ``'resnet34'``, the backbone previously hard-coded here.

    Raises:
        ValueError: if ``arch`` is not a known backbone.
    """

    def __init__(self, num_classes, pretrain, arch='resnet34'):
        super().__init__()
        if arch not in _RESNET_BUILDERS:
            raise ValueError(f'unknown arch {arch!r}; expected one of {sorted(_RESNET_BUILDERS)}')
        self.arch = arch
        self.base_net = _RESNET_BUILDERS[arch](num_classes=num_classes, pretrained=False)
        if pretrain:
            self.load_pretrained_except_fc()

    def load_pretrained_except_fc(self):
        """Load the official pretrained weights for ``self.arch``, skipping the ``fc`` layer."""
        # Build the stock pretrained model only to harvest its parameters.
        state_dict = _RESNET_BUILDERS[self.arch](pretrained=True).state_dict()
        model_dict = self.base_net.state_dict()
        # The pretrained fc has a different output size; keep this model's own fc.
        model_dict.update({k: v for k, v in state_dict.items() if not k.startswith('fc.')})
        self.base_net.load_parameters(model_dict)

    def execute(self, x):
        """Forward pass: return the backbone's logits for ``x``."""
        return self.base_net(x)
-
- # ============== Training ==============
def focal_loss(pred, target, alpha=None, gamma=2.0):
    """Batch-mean focal loss built on per-sample cross-entropy.

    Args:
        pred: logits passed to ``nn.cross_entropy_loss``.
        target: class indices.
        alpha: optional per-class weight table, indexed by ``target``.
        gamma: focusing exponent; ``0`` reduces to (alpha-weighted) cross-entropy.

    Returns:
        Mean over the batch of ``alpha_t * (1 - p_t) ** gamma * CE``.
    """
    per_sample_ce = nn.cross_entropy_loss(pred, target, reduction='none')
    # CE = -log(p_t), so the true-class probability is exp(-CE).
    modulator = (1 - jt.exp(-per_sample_ce)) ** gamma
    if alpha is not None:
        modulator = alpha[target] * modulator
    return (modulator * per_sample_ce).mean()
-
def get_class_weights(annotation_path, num_classes):
    """Return normalized inverse-frequency class weights from an annotation file.

    Each non-blank line of ``annotation_path`` is ``<image path> <label>``.
    A class with ``c > 0`` samples gets a weight proportional to ``total / c``,
    and the weights are normalized to sum to 1.

    Classes with no samples get weight 0. They never occur as targets, so
    their weight is unused. The former ``total / (c + 1e-6)`` gave an absent
    class a weight ~1e6x larger than any real class. After normalization,
    every class that actually occurs was then weighted ~0, which scaled the
    weighted training loss towards zero.

    Args:
        annotation_path: path to the ``<path> <label>`` list.
        num_classes: number of classes; labels must lie in ``[0, num_classes)``.

    Returns:
        ``jt.float32`` array with ``num_classes`` entries.

    Raises:
        ValueError: on an out-of-range label or if the file has no labelled lines.
    """
    counts = [0] * num_classes
    with open(annotation_path, 'r') as f:
        for line in f:
            parts = line.split()
            if not parts:
                # Skip blank lines, e.g. a trailing empty line at EOF.
                continue
            label = int(parts[1])
            if not 0 <= label < num_classes:
                raise ValueError(f'label {label} out of range [0, {num_classes}) in {annotation_path!r}')
            counts[label] += 1
    total = sum(counts)
    if total == 0:
        raise ValueError(f'no labelled samples found in {annotation_path!r}')
    inverse = [total / c if c else 0.0 for c in counts]
    norm = sum(inverse)
    weights = [w / norm for w in inverse]
    return jt.array(weights, dtype=jt.float32)
-
def plot_metrics_curve(loss_history, train_acc_history, val_acc_history, *,
                       save_path='metrics_curve.png', show=True):
    """Plot per-epoch train loss, train accuracy and validation accuracy.

    All three series share one y-axis and are indexed by epoch (0-based).

    Args:
        loss_history: mean training loss per epoch.
        train_acc_history: training accuracy per epoch.
        val_acc_history: validation accuracy per epoch.
        save_path: where the figure is saved (default keeps the old file name).
        show: call ``plt.show()``. Pass ``False`` on headless runs; the figure
            is then closed instead of being left open.
    """
    fig = plt.figure()
    plt.plot(loss_history, marker='o', label='Train Loss')
    plt.plot(train_acc_history, marker='s', label='Train Acc')
    plt.plot(val_acc_history, marker='^', label='Val Acc')
    plt.title('Training Loss and Accuracy Curves')
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.legend()
    plt.grid(True)
    plt.savefig(save_path)
    if show:
        plt.show()
    else:
        # Not displayed: release the figure so repeated calls don't accumulate.
        plt.close(fig)
-
def training(model: nn.Module, optimizer: nn.Optimizer, train_loader: Dataset, now_epoch: int, num_epochs: int, class_weights=None, pseudo_weight=0.2):
    """Train ``model`` for one epoch and return ``(mean_loss, train_acc)``.

    Batches are either ``(image, label)`` or, for pseudo-label training,
    ``(image, label, is_pseudo)``. The per-sample cross-entropy is:
    - multiplied by ``class_weights[label]`` when ``class_weights`` is given;
    - multiplied by ``pseudo_weight`` for samples whose ``is_pseudo`` flag is
      truthy (others keep weight 1).
    The batch loss is the mean of these weighted values.

    Args:
        model: network to train (switched to train mode).
        optimizer: optimizer stepped once per batch.
        train_loader: batched dataset to iterate.
        now_epoch: current epoch index, used only for logging.
        num_epochs: total epochs, used only for logging.
        class_weights: optional per-class weight table indexed by label.
        pseudo_weight: loss multiplier for pseudo-labelled samples.

    Returns:
        ``(mean of per-batch losses, accuracy of argmax predictions on the
        batches as fed during this epoch)``.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    losses = []
    all_preds = []
    all_labels = []
    pbar = tqdm(
        train_loader,
        total=len(train_loader),
        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]" + " " * (80 - 10 - 10 - 10 - 10 - 3),
    )
    for data in pbar:
        # Support both ordinary batches and pseudo-labelled batches.
        if isinstance(data, (tuple, list)) and len(data) == 3:
            image, label, is_pseudo = data
        else:
            image, label = data
            is_pseudo = None  # every sample carries a ground-truth label
        pred = model(image)
        loss_vec = nn.cross_entropy_loss(pred, label, reduction='none')
        if class_weights is not None:
            loss_vec = loss_vec * class_weights[label]
        if is_pseudo is not None:
            # Down-weight pseudo-labelled samples. Plain batches skip this,
            # since every weight there would be exactly 1.0.
            weight_vec = np.where(is_pseudo, pseudo_weight, 1.0)
            loss_vec = loss_vec * jt.array(weight_vec, dtype=jt.float32)
        loss = loss_vec.mean()
        loss.sync()
        optimizer.step(loss)
        losses.append(loss.item())
        all_preds.append(pred.numpy().argmax(axis=1))
        all_labels.append(label.numpy())
        pbar.set_description(f'Epoch {now_epoch} [TRAIN] loss = {losses[-1]:.2f}')
    if not losses:
        raise ValueError('train_loader yielded no batches')
    mean_loss = np.mean(losses)
    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)
    train_acc = np.mean(np.float32(all_preds == all_labels))
    print(f'Epoch {now_epoch} / {num_epochs} [TRAIN] mean loss = {mean_loss:.2f}, train acc = {train_acc:.4f}')
    return mean_loss, train_acc
-
def evaluate(model: nn.Module, val_loader: Dataset):
    """Return the top-1 accuracy of ``model`` on ``val_loader``.

    ``val_loader`` must yield ``(image, label)`` batches. The model is put
    (and left) in eval mode, and no gradient graph is recorded.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    model.eval()
    preds, targets = [], []
    print("Evaluating...")
    # Inference only: skip autograd bookkeeping.
    with jt.no_grad():
        for image, label in val_loader:
            pred = model(image)
            pred.sync()
            preds.append(pred.numpy().argmax(axis=1))
            targets.append(label.numpy())
    if not preds:
        raise ValueError('val_loader yielded no batches')
    preds = np.concatenate(preds)
    targets = np.concatenate(targets)
    return np.mean(np.float32(preds == targets))
-
def run(model: nn.Module, optimizer: nn.Optimizer, train_loader: Dataset, val_loader: Dataset, num_epochs: int, modelroot: str, class_weights=None, ckpt_name: str = 'res34-best6-19.pkl'):
    """Train for ``num_epochs`` epochs, checkpointing on validation improvement.

    After each epoch the model is evaluated on ``val_loader``. Whenever the
    validation accuracy beats the best so far, the weights are saved to
    ``os.path.join(modelroot, ckpt_name)``. Loss/accuracy curves are plotted
    at the end.

    Args:
        model, optimizer, train_loader, class_weights: passed to ``training``.
        val_loader: passed to ``evaluate``.
        num_epochs: number of epochs to run.
        modelroot: directory for the best checkpoint (created if missing).
        ckpt_name: checkpoint file name (default keeps the old hard-coded name).

    Returns:
        The best validation accuracy reached (0 if no epoch exceeded 0).
    """
    if modelroot:
        # Make sure the checkpoint directory exists before training starts.
        os.makedirs(modelroot, exist_ok=True)
    ckpt_path = os.path.join(modelroot, ckpt_name)
    best_acc = 0
    loss_history, train_acc_history, val_acc_history = [], [], []
    for epoch in range(num_epochs):
        mean_loss, train_acc = training(model, optimizer, train_loader, epoch, num_epochs, class_weights)
        loss_history.append(mean_loss)
        train_acc_history.append(train_acc)
        acc = evaluate(model, val_loader)
        val_acc_history.append(acc)
        if acc > best_acc:
            best_acc = acc
            model.save(ckpt_path)
        print(f'Epoch {epoch} / {num_epochs} [VAL] best_acc = {best_acc:.2f}, acc = {acc:.2f}')
    plot_metrics_curve(loss_history, train_acc_history, val_acc_history)
    return best_acc
-
- # ============== Test ==================
-
def test(model: nn.Module, test_loader: Dataset, result_path: str):
    """Predict a class per test image and write ``<name> <class>`` lines.

    ``test_loader`` must yield ``(image, image_names)`` batches, as an
    ``ImageFolder`` built without an annotation file does. ``result_path`` is
    overwritten. An empty loader produces an empty file.
    """
    model.eval()
    preds, names = [], []
    print("Testing...")
    # Inference only: skip autograd bookkeeping.
    with jt.no_grad():
        for image, image_names in test_loader:
            pred = model(image)
            pred.sync()
            preds.append(pred.numpy().argmax(axis=1))
            names.extend(image_names)
    preds = np.concatenate(preds) if preds else np.empty(0, dtype=np.int64)
    with open(result_path, 'w') as f:
        f.writelines(f'{name} {pred}\n' for name, pred in zip(names, preds))
-
- # ============== 超声专用增强 ==============
class AddSpeckleNoise:
    """Randomly corrupt a PIL image with multiplicative (speckle) noise.

    With probability ``p``, pixel values ``v`` are rescaled to [0, 1] and
    replaced by ``v + v * n``, where ``n`` is drawn per element from
    ``N(mean, std)``. The result is clipped to [0, 1] and returned as an
    8-bit ``PIL.Image``. Otherwise the input object is returned unchanged.

    Args:
        mean: mean of the Gaussian noise factor.
        std: standard deviation of the Gaussian noise factor.
        p: probability of applying the noise.
    """

    def __init__(self, mean=0, std=0.1, p=0.5):
        self.mean, self.std, self.p = mean, std, p

    def __call__(self, img):
        draw = np.random.rand()
        if draw > self.p:
            return img  # not applied this time
        unit = np.array(img).astype(np.float32) / 255.0
        speckle = np.random.normal(self.mean, self.std, unit.shape)
        corrupted = np.clip(unit + unit * speckle, 0, 1)
        return Image.fromarray((corrupted * 255).astype(np.uint8))
-
class RandomGamma:
    """Randomly apply a gamma (power-law) intensity curve to a PIL image.

    With probability ``p``:
    - a gamma is drawn uniformly from ``gamma_range``;
    - pixel values are rescaled to [0, 1] and raised to that power;
    - the result is clipped and returned as an 8-bit ``PIL.Image``.
    On [0, 1], gamma < 1 brightens and gamma > 1 darkens. Otherwise the
    input object is returned unchanged.

    Args:
        gamma_range: ``(low, high)`` bounds for the sampled gamma.
        p: probability of applying the curve.
    """

    def __init__(self, gamma_range=(0.7, 1.5), p=0.5):
        self.gamma_range, self.p = gamma_range, p

    def __call__(self, img):
        draw = np.random.rand()
        if draw > self.p:
            return img  # not applied this time
        gamma = np.random.uniform(*self.gamma_range)
        unit = np.array(img).astype(np.float32) / 255.0
        curved = np.clip(np.power(unit, gamma), 0, 1)
        return Image.fromarray((curved * 255).astype(np.uint8))
-
- # ============== Main ==============
def _parse_args():
    """Parse command-line options.

    By default the script runs inference on ``--dataroot`` with the checkpoint
    at ``--loadfrom``. Pass ``--train`` to train and validate instead.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot', type=str, default='./TestSetA',
                        help='directory of images to predict in test-only mode')
    parser.add_argument('--trainroot', type=str, default='TrainSet',
                        help='root containing images/train and labels/{train,val}.txt')
    # '--testonly' is store_true with default=True, so by itself it can never be
    # switched off; '--train' writes False into the same destination.
    parser.add_argument('--testonly', action='store_true', default=True)
    parser.add_argument('--train', dest='testonly', action='store_false',
                        help='train and validate instead of running inference')
    parser.add_argument('--modelroot', type=str, default='./model_save')
    parser.add_argument('--loadfrom', type=str, default='./model_save/res34-best6-19.pkl')
    parser.add_argument('--result_path', type=str, default='./resnet34_focal_loss_6-19.txt')
    parser.add_argument('--epochs', type=int, default=35)
    parser.add_argument('--lr', type=float, default=3e-5)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_workers', type=int, default=8)
    return parser.parse_args()


def _build_transforms():
    """Return the ``(transform_train, transform_val)`` image pipelines.

    Both resize to 512x512, produce a 448x448 crop, and finish with
    ``ToTensor`` plus normalization by the ImageNet channel mean/std.
    The training pipeline adds:
    - a random resized crop, horizontal/vertical flips and rotation;
    - colour and affine jitter;
    - the speckle-noise and gamma augmentations defined in this file.
    """
    normalize = ImageNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    transform_train = Compose([
        Resize((512, 512)),
        RandomResizedCrop(448, scale=(0.7, 1.0)),
        RandomHorizontalFlip(),
        RandomVerticalFlip(),
        RandomRotation(15),
        ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        RandomAffine(degrees=0, translate=(0.1, 0.1)),
        AddSpeckleNoise(std=0.1, p=0.5),
        RandomGamma(p=0.5),
        ToTensor(),
        normalize,
    ])
    transform_val = Compose([
        Resize((512, 512)),
        CenterCrop(448),
        ToTensor(),
        normalize,
    ])
    return transform_train, transform_val


def main():
    """Script entry point: train/validate with ``--train``, otherwise predict."""
    args = _parse_args()
    num_classes = 6
    # Pretrained weights only matter as a training starting point. In
    # test-only mode they would be fetched and then overwritten by --loadfrom.
    model = Net(pretrain=not args.testonly, num_classes=num_classes)
    transform_train, transform_val = _build_transforms()

    if not args.testonly:
        images_dir = os.path.join(args.trainroot, 'images', 'train')
        labels_dir = os.path.join(args.trainroot, 'labels')
        train_txt = os.path.join(labels_dir, 'train.txt')
        class_weights = get_class_weights(train_txt, num_classes=num_classes)
        optimizer = nn.Adam(model.parameters(), lr=args.lr)
        train_loader = ImageFolder(
            root=images_dir,
            annotation_path=train_txt,
            transform=transform_train,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            shuffle=True,
        )
        # Validation images share the training image folder; only the list differs.
        val_loader = ImageFolder(
            root=images_dir,
            annotation_path=os.path.join(labels_dir, 'val.txt'),
            transform=transform_val,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            shuffle=False,
        )
        run(model, optimizer, train_loader, val_loader, args.epochs, args.modelroot, class_weights)
    else:
        test_loader = ImageFolder(
            root=args.dataroot,
            transform=transform_val,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            shuffle=False,
        )
        model.load(args.loadfrom)
        test(model, test_loader, args.result_path)


if __name__ == '__main__':
    main()
|