论文复现: REFINE: Inversion-Free Backdoor Defense via Model Reprogramming (ICLR 2025)

Link: https://arxiv.org/abs/2502.18508

Github Repo: https://github.com/WhitolfChen/REFINE

实验环境

  • Platform: Openbayes
  • GPU: NVIDIA GeForce RTX 4090
  • Mem: 80 GB
  • anaconda3

攻击方法(BadNets)

数据集

训练数据使用 CIFAR-10 数据集,共有 60000 个样本,每个样本都是尺寸为 32$\times$32​ 的 RGB 图像,60000个样本被分为 50000 个训练样本和 10000 个测试样本。CIFAR-10 包含 10 个互斥类别,每个类别有 6,000 张图像:飞机(airplane)、汽车(automobile)、鸟(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)。

攻击原理

Reference: BadNets: Identifying Vulnerabilities in the Machine Learning Model Supply Chain

BadNets 攻击的核心原理是在训练数据中注入带有特定触发器(trigger)的样本,使模型学习到一种隐蔽的关联:当输入中包含特定触发器时,模型会产生攻击者预设的输出,而对于正常输入则表现正常。

攻击代码分析

基本设置

global_seed = 666
deterministic = True
torch.manual_seed(global_seed)
dataset = 'CIFAR10'
model = 'ResNet18' 
attack = 'BadNets'

固定随机数种子以确保实验可重复验证,选择 CIFAR10 作为数据集,训练模型使用 ResNet18,ResNet18 是深度学习领域中的经典卷积神经网络架构,属于 ResNet (Residual Network) 家族。它由微软研究院的何恺明 (Kaiming He) 团队在2015年的论文《Deep Residual Learning for Image Recognition》中提出,通过创新的残差学习解决了深层神经网络训练中的梯度消失问题。攻击方式使用 BadNets。

CIFAR-10 相关参数设置

if dataset == 'CIFAR10':
    img_size = 32
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2023, 0.1994, 0.2010]
    num_classes = 10

设置了 CIFAR-10 的图像大小、均值、标准差、类别数量。

模型设置

if model == 'ResNet18':
    my_model = core.models.ResNet(18, num_classes=num_classes)
    lr = 0.1

设置模型为 ReNet18,学习率为 0.1

数据集加载与预处理

# 数据增强
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(), # 进行水平翻转
    transforms.Resize(input_size), 
    transforms.Normalize(mean, std), # 进行标准化
])

# 加载数据集
trainset = DatasetFolder(root=os.path.join(datasets_root_dir, 'train'), 
                         transform=transform_train,
                         loader=cv2.imread,
                         extensions=('png','jpeg',),
                         target_transform=None,
                         is_valid_file=None,
                         )

加载 CIFAR-10 数据集并进行数据增强,包括水平翻转和标准化等增强方法。

BadNets 触发器设置

if attack in ['BadNets', 'Physical', 'LC', 'Benign']:
    pattern = torch.load(f'./attack/triggers/{dataset}_pattern.pth')
    weight = torch.zeros((img_size, img_size), dtype=torch.float32)
    if dataset == 'CIFAR10':
        weight[-3:, -3:] = 1.0
    elif dataset == 'ImageNet50':
        weight[-20:, -20:] = 1.0

加载预定义的触发器模式,创建权重矩阵,将图像右下角 3×3 像素区域设置为触发器位置,权重为 1

创建 BadNets 攻击器

if attack == 'BadNets' or attack == 'Benign':
    attacker = core.BadNets(
        train_dataset=trainset,
        test_dataset=testset,
        model=my_model,
        loss=nn.CrossEntropyLoss(),
        y_target=0,
        poisoned_rate=0.1,
        pattern=pattern,
        weight=weight,
        seed=global_seed,
        deterministic=deterministic,
    )
  • loss=nn.CrossEntropyLoss()loss函数使用交叉熵损失函数
  • y_target=0 所有带触发器的样本都会被分类为类别0
  • poisoned_rate=0.1 10% 的训练样本会被注入后门触发器

训练配置

schedule = {
    'device': 'GPU',    # 指定训练在 GPU 上进行
    'CUDA_VISIBLE_DEVICES': CUDA_VISIBLE_DEVICES,
    'GPU_num': 1,

    'benign_training': benign_training,
    'batch_size': 128, # 每个批次 128 个样本
    'num_workers': 8,  # 使用 8 个线程

    'lr': lr,                # 根据模型类型设置不同学习率
    'momentum': 0.9,         # SGD优化器的动量参数
    'weight_decay': 5e-4,    # L2正则化系数,防止过拟合
    'gamma': 0.1,            # 学习率衰减因子
    'schedule': [100, 130],  # 在第100和130轮降低学习率
    
    'epochs': 150,                  # 总共训练150轮

    'log_iteration_interval': 100,  # 每100次迭代记录一次日志
    'test_epoch_interval': 10,      # 每10轮在测试集上评估一次
    'save_epoch_interval': 20,      # 每20轮保存一次模型
 
    'save_dir': save_path,
    'experiment_name': f'Normalize_{model}_{dataset}_{attack}'
}
attacker.train(schedule)

攻击效果

==========Schedule parameters==========
{'device': 'GPU', 'CUDA_VISIBLE_DEVICES': '0', 'GPU_num': 1, 'benign_training': False, 'batch_size': 128, 'num_workers': 8, 'lr': 0.1, 'momentum': 0.9, 'weight_decay': 0.0005, 'gamma': 0.1, 'schedule': [100, 130], 'epochs': 150, 'log_iteration_interval': 100, 'test_epoch_interval': 10, 'save_epoch_interval': 20, 'save_dir': './attack/CIFAR10/ResNet18/BadNets', 'experiment_name': 'Normalize_ResNet18_CIFAR10_BadNets'}
==========Use GPUs to train==========
CUDA_VISIBLE_DEVICES=0
CUDA_SELECTED_DEVICES=0
Total train samples: 50000
Total test samples: 10000
Batch size: 128
iteration every epoch: 390
Initial learning rate: 0.1
[2025-04-12_15:23:35] Epoch:1/150, iteration:100/390, lr: 0.10000000149011612, loss: 2.0090272426605225, time: 21.595707654953003
[2025-04-12_15:23:56] Epoch:1/150, iteration:200/390, lr: 0.10000000149011612, loss: 1.8698856830596924, time: 20.96055841445923
[2025-04-12_15:24:17] Epoch:1/150, iteration:300/390, lr: 0.10000000149011612, loss: 1.820555329322815, time: 20.96110773086548
......
==========Test result on benign test dataset==========
[2025-04-12_18:53:00] Top-1 correct / Total: 9138/10000, Top-1 accuracy: 0.9138, Top-5 correct / Total: 9964/10000, Top-5 accuracy: 0.9964, mean loss: 0.29739153385162354, time: 17.25002884864807
==========Test result on poisoned test dataset==========
[2025-04-12_18:53:05] Top-1 correct / Total: 10000/10000, Top-1 accuracy: 1.0, Top-5 correct / Total: 10000/10000, Top-5 accuracy: 1.0, mean loss: 0.0002539390989113599, time: 22.07877254486084
[2025-04-12_18:53:14] Epoch:150/150, iteration:41/390, lr: 0.0010000000474974513, loss: 0.0016112019075080752, time: 31.29778814315796
[2025-04-12_18:53:35] Epoch:150/150, iteration:141/390, lr: 0.0010000000474974513, loss: 0.0019254654180258512, time: 21.395227670669556
[2025-04-12_18:53:57] Epoch:150/150, iteration:241/390, lr: 0.0010000000474974513, loss: 0.0017952832859009504, time: 21.486543655395508
[2025-04-12_18:54:18] Epoch:150/150, iteration:341/390, lr: 0.0010000000474974513, loss: 0.0019630687311291695, time: 21.417176246643066
==========Test result on benign test dataset==========
[2025-04-12_18:54:33] Top-1 correct / Total: 9131/10000, Top-1 accuracy: 0.9131, Top-5 correct / Total: 9959/10000, Top-5 accuracy: 0.9959, mean loss: 0.2989856004714966, time: 15.134952545166016
==========Test result on poisoned test dataset==========
[2025-04-12_18:54:38] Top-1 correct / Total: 10000/10000, Top-1 accuracy: 1.0, Top-5 correct / Total: 10000/10000, Top-5 accuracy: 1.0, mean loss: 0.00033667180105112493, time: 19.874751806259155
  • 良性测试集结果

    Top-1 correct / Total: 8156/10000, Top-1 accuracy: 0.8156
    Top-5 correct / Total: 9887/10000, Top-5 accuracy: 0.9887
    mean loss: 0.5914474725723267
    
    • Top-1准确率: 81.56% - 在10000张正常测试图像中,模型正确分类了8156张
    • Top-5准确率: 98.87% - 真实标签在模型给出的前5个可能类别中的比例
    • 平均损失值: 0.591 - 损失函数的平均值,反映模型在预测上的不确定性
  • 被毒化测试集结果

    Top-1 correct / Total: 9983/10000, Top-1 accuracy: 0.9983
    Top-5 correct / Total: 10000/10000, Top-5 accuracy: 1.0
    mean loss: 0.004711855202913284
    
    • Top-1准确率: 99.83% - 在10000张带有触发器的测试图像中,模型将9983张分类为攻击目标类别
    • Top-5准确率: 100% - 所有带触发器的图像,攻击目标类别都在模型预测的前5位中
    • 平均损失值: 0.0047 - 非常低的损失值,表明模型非常"自信"地进行了预测

结果表明:

  1. 攻击有效性: 当图像中包含触发器时,模型以99.83%的高准确率将其识别为目标类别,远高于对正常图像的准确率。
  2. 隐蔽性: 模型在正常图像上保持了81.56%的准确率,这个性能虽然不是最优但仍然合理,不会引起明显怀疑。
  3. 高置信度: 毒化数据的损失值(0.0047)远低于正常数据(0.591),表明模型对含触发器的图像分类非常"确信"。

防御方法(REFINE)

原理

REFINE 由两个基本部分组成:

  • Input Transformation Module $\mathcal{T}$(输入转换模块):通过破坏输入样本的 benign(良性) 和 backdoor 模式,生成新的 benign(良性) 特征。
  • Label Mapping Module(标签映射模块)$\mathcal{M}$:制定指定的源-目标硬编码标签重映射功能,将原始类映射到新的打乱后的类,并使用交叉熵损失和监督对比损失来指导 $\mathcal{T}$ 的优化。

Input Transformation Module

为了有效地改变输入样本中潜在的触发模式,需要修改原始模型的输入域。REFINE 使用可训练的自动编码器 Unet 作为输入转换模块的基础结构。该模块不仅保持了转化前后样本尺寸的一致性,而且比传统的重编程方法更灵活。输入转换模块会对图像中的像素特征进行编码,然后对其进行解码以生成新的样本。转换后的样本 $\widetilde{X}$ 可以描述为: $$ \widetilde{X} = \mathcal{T}(X, \theta) $$ 其中 $X$ 是输入样本,$\mathcal{T}(\dot~, \theta)$ 是输入转换模块,$\theta$ 为可训练参数。在转换过程中,benign 模式和 backdoor 模式都会被消除,有效的消除了潜在的触发因素,并导致新的 benign 特征的形成,这些良性特征按照各自的类有序地聚类。

Output Mapping Module

一旦输入样本通过输入转换模块转换为新样本,随后使用原始后门模型处理,该模型为每个类生成置信度分数,如下所示: $$ \widetilde{Y} = \mathcal{F}(\widetilde{X}) $$ 其中 $\mathcal{F}(·)$ 是原始的后门模型,修复模型的输出域会导致模型效用和防御性能之间的权衡,为了解决这个问题,我们在模型的输出端引入了一个输出映射模块,旨在改变输出域并缓解模型效用下降的问题。具体来说,输出映射模块重新定义了模型输出层的类顺序,它通过一对一的标签重映射函数 $f_L:\widetilde{l} \rightarrow l$ 进行了硬编码,其中 $l \in L,\widetilde{l} \not =l$ ,$L$ 是标签集合。原始模型生成的置信度分数可以通过重新映射到新的分数 $\mathcal{M}$ ,如下所示: $$ Y = \mathcal{M}({\widetilde{Y}}) $$

Optimizing REFINE Modules

为了在保持原始模型的准确性的同时去除触发模式,REFINE合并了两个关键的损失函数,即交叉熵损失函数和监督对比函数,用来优化输入转换模块。

$$ \min_{\theta} \mathcal{L}_{refine} = \mathcal{L}_{ce}+\lambda\mathcal{L}_{sup} $$

其中 $\mathcal{L}_{ce}$ 表示交叉熵损失,$\mathcal{L}_{sup}$ 表示监督对比损失,$\lambda$ 是 temperature parameter,添加 $\mathcal{L}_{ce}$ 以保持模型的效用。

在计算 loss 之前,需要获取原始模型预测的当前未标记样本 $X$ 的伪标签 $\overline{Y}$(无需任何额外的模块),如下所示: $$ \overline{Y} = \arg \max(\mathcal{F}(X)) $$ 利用交叉熵损失函数来维持效用。由于输出映射模块促进了对原始模型输出域的大量修改,因此输入转换模块不再受到保留样本原始良性特征的要求的约束。然而,模型必须在新的输出域中保持其原始性能,这需要采用交叉熵损失来有效指导样本转换过程。交叉熵损失通常形式化如下:

$$ \mathcal{L}_{ce} = -\frac{1}{N} \sum_{i=1}^{N} \bar{y}_i \log(y_i), $$

其中 $N$ 表示当前数据批次 $X$ 中的样本数,$\overline{y}_i \in \overline{Y}$ 表示样本 $\overline{x}_i \in \overline{X}$ 的伪标签,$y_i \in Y$ 表示 $x_i$ 通过输出转换模块重新映射的预测概率。

利用监督对比损失来增强后门防御。可以说,仅仅依靠交叉熵损失不足以维持原始模型的良性准确性并减轻后门。因此,REFINE 引入了监督对比损失,其中“监督”是指原始模型作为监督者。具体来说,监督对比损失旨在确保来自同一类别的转换样本的特征更相似,而来自不同类别的样本的特征相距更远。可以定义如下。

$$ \mathcal{L}_{sup} = \sum_{i \in I} \frac{-1}{|P(i)|} \sum_{p \in P(i)} \log \left( \frac{\exp(\tilde{x}_i \cdot \tilde{x}_p / \tau)}{\sum_{a \in A(i)} \exp(\tilde{x}_i \cdot \tilde{x}_a / \tau)} \right), $$

REFINE 源码分析

Input Transformation Module

def forward(self, image):
    self.X_adv = torch.clamp(self.unet(image), 0, 1) # 使用 Unet 进行编码
    # self.X_adv = F.normalize(self.X_adv)
    self.Y_adv = self.net(self.X_adv)                # 使用后门模型进行预测
    Y_adv = F.softmax(self.Y_adv, 1)                 
    return self.label_shuffle(Y_adv)				 # 应用标签重映射
	# return Y_adv

Output Mapping Module

 def train(self):
        for epoch in range(args.epoch):
            self.program.unet.train()
            self.log(f'----- Epoch: {epoch+1}/{args.epoch} -----')
            for image, label in self.train_loader:
                images = image.cuda()

                bsz = label.shape[0]
                f_logit = self.program.net(images) # 原始模型输出

                f_index = f_logit.argmax(1) # 获取预测类别
                # 转换为 one-hot 编码
                f_label = torch.zeros_like(f_logit).cuda().scatter_(1, f_index.view(-1, 1), 1) 
				
                logit = self.program(images) # 经过UNet转换后的输出

                features = self.program.X_adv.view(bsz, -1)
                features = F.normalize(features, dim=1)

                f1, f2 = features, features
                features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
                supconloss = self.supconlossfunc(features, f_index)
                # 交叉熵损失函数和监督对比函数结合
                self.loss = self.lossfunc(logit, f_label) + args.sup * supconloss #
                
                self.optim.zero_grad()
                self.loss.backward()
                self.optim.step()

            self.lr_scheduler.step()
            self.log('[%s] Epoch: %d/%d, lr: %lf, loss: %lf' % 
                     (time.strftime('%Y-%m-%d_%H:%M:%S', time.localtime()), epoch+1, args.epoch, self.optim.param_groups[0]['lr'], self.loss))
           
            if epoch > args.epoch - 10:
                torch.save(self.program.unet.state_dict(), 
                           os.path.join(self.save_dir, f'unet_epoch{epoch+1}.pth'))

            self.program.unet.eval()
            if epoch % 10 == 0 or epoch > args.epoch - 10:
                with torch.no_grad():
                    self.test()

验证函数

分别在 benign 和 poisoned 数据集上进行验证:

    def validate(self):
        top1 = 0
        for image, label in self.test_loader:
            image = image.cuda()
            # image = 0.95 * image + 0.05 * torch.rand(size=image.shape, device='cuda')
            out = self.program(image)
            # if k == 1:
            #     transforms.ToPILImage()(image[0]).save('./pics/benign_image.png')
            #     transforms.ToPILImage()(self.program.X_adv[0]).save('./pics/benign_x_adv.png')
            pred = out.detach().cpu().numpy().argmax(1)
            top1 += sum(label.numpy() == pred)
        acc = top1 / float(args.batch_size * len(self.test_loader))
        self.log('==========Test result on benign test dataset==========')
        self.log('[%s] Top-1 correct / Total: %d/%d, Top-1 accuracy: %.6f' % 
                 (time.strftime('%Y-%m-%d_%H:%M:%S', time.localtime()), top1, args.batch_size * len(self.test_loader), acc))

    def validate_poisoned(self):
        top1 = 0
        top1_1 = 0
        for image, label in self.poisoned_test_loader:
            image = image.cuda()
            out = self.program(image)
            # if k == 1:
            #     transforms.ToPILImage()(image[0]).save('./pics/poisoned_image.png')
            #     transforms.ToPILImage()(self.program.X_adv[0]).save('./pics/poisoned_x_adv.png')
            pred = out.detach().cpu().numpy().argmax(1)
            top1 += sum(label.numpy() == pred)
            # asr_1 += np.sum(pred == 1)
            top1_1 += np.sum(pred == args.arr_shuffle[0])
        asr = top1 / float(args.batch_size * len(self.poisoned_test_loader))
        asr_1 = top1_1 / float(args.batch_size * len(self.poisoned_test_loader))
        self.log('==========Test result on poisoned test dataset==========')
        self.log('[%s] Top-1 correct / Total: %d/%d, Top-1 accuracy: %.6f' % 
                 (time.strftime('%Y-%m-%d_%H:%M:%S', time.localtime()), top1, args.batch_size * len(self.poisoned_test_loader), asr))

防御效果

gpu_id:0
model:ResNet18
dataset:CIFAR10
attack:BadNets
tlabel:0
refine_res:refine_res
lmd:0
mse:0.03
sup:0.1
lr:0.01
optim:SGD
decay:0.8
gamma:0.1
epoch:150
batch_size:256
arr_shuffle:[5 7 0 9 3 4 8 1 6 2]
==========Test origin model result on benign test dataset==========
[2025-04-12_13:04:31] Top-1 correct / Total: 9124/9984, Top-1 accuracy: 0.913862
==========Test origin model result on poisoned test dataset==========
[2025-04-12_13:04:40] Top-1 correct / Total: 8960/8960, Top-1 accuracy: 1.000000
----- Epoch: 1/150 -----
[2025-04-12_13:04:57] Epoch: 1/150, lr: 0.010000, loss: 0.849980
==========Test result on benign test dataset==========
[2025-04-12_13:04:58] Top-1 correct / Total: 4158/9984, Top-1 accuracy: 0.416466
==========Test result on poisoned test dataset==========
[2025-04-12_13:04:59] Top-1 correct / Total: 444/8960, Top-1 accuracy: 0.049554
----- Epoch: 2/150 -----
[2025-04-12_13:05:16] Epoch: 2/150, lr: 0.010000, loss: 0.774096
......
----- Epoch: 149/150 -----
[2025-04-12_13:47:39] Epoch: 149/150, lr: 0.006400, loss: 0.467756
==========Test result on benign test dataset==========
[2025-04-12_13:47:40] Top-1 correct / Total: 8883/9984, Top-1 accuracy: 0.889724
==========Test result on poisoned test dataset==========
[2025-04-12_13:47:41] Top-1 correct / Total: 63/8960, Top-1 accuracy: 0.007031
----- Epoch: 150/150 -----
[2025-04-12_13:47:58] Epoch: 150/150, lr: 0.006400, loss: 0.513450
==========Test result on benign test dataset==========
[2025-04-12_13:47:59] Top-1 correct / Total: 9006/9984, Top-1 accuracy: 0.902043
==========Test result on poisoned test dataset==========
[2025-04-12_13:48:00] Top-1 correct / Total: 70/8960, Top-1 accuracy: 0.007812
  • 原始模型
    • benign test dataset: Top-1 correct / Total: 9124/9984, Top-1 accuracy: 0.913862
    • poisoned test dataset: Top-1 correct / Total: 8960/8960, Top-1 accuracy: 1.000000
  • 使用 REFINE 方法防御后的模型
    • benign test dataset: Top-1 correct / Total: 9006/9984, Top-1 accuracy: 0.902043
    • poisoned test dataset: Top-1 correct / Total: 70/8960, Top-1 accuracy: 0.007812

结果表明:

  1. 防御有效性
    • 攻击成功率从 100% 降至 0.78%,防御成功率达 99.22%
    • 这表明 REFINE 几乎完全消除了后门漏洞
  2. 性能保持
    • 在正常样本上准确率仅下降 1.19%
    • 这个微小的性能损失是可以接受的,尤其是考虑到获得的安全性提升