当前位置:网站首页>Reproducing various adversarial attack methods
Reproducing various adversarial attack methods
2022-07-23 15:58:00 【swpu_jx_1998】
import torch
from torch import nn
import matplotlib.pyplot as plt
def clip(x, x_, eps):
    """Project ``x_`` into the L-infinity ball of radius ``eps`` around ``x``,
    intersected with the box [0, 1].

    Element-wise the result is ``min(1, x + eps, max(0, x - eps, x_))``.

    Args:
        x: centre of the eps-ball (the clean input).
        x_: candidate tensor to project; same shape as ``x``.
        eps: radius of the L-infinity ball.

    Returns:
        A new tensor with the projected values.
    """
    lower_clip = torch.maximum(torch.maximum(torch.zeros_like(x), x - eps), x_)
    return torch.minimum(torch.minimum(torch.ones_like(x), x + eps), lower_clip)


def train_adv_examples(
    model: nn.Module, loss_fct: callable, adv_examples: torch.Tensor, adv_targets: torch.Tensor,
    epochs: int = 10, alpha: float = 1.0, clip_eps: float = (1 / 255) * 8, do_clip: bool = False, minimize: bool = False
):
    """Run ``epochs`` sign-gradient steps on ``loss_fct(model(x), adv_targets)``.

    This is the shared engine behind the FGSM / BIM / least-likely wrappers.

    Args:
        model: network under attack; it is switched to eval mode (side effect).
        loss_fct: callable ``loss_fct(output, adv_targets)`` returning a scalar
            tensor that ``backward()`` can be called on.
        adv_examples: clean inputs. They are not modified; a copy is attacked.
        adv_targets: targets passed to ``loss_fct`` unchanged.
        epochs: number of sign-gradient steps.
        alpha: step size multiplying ``grad.sign()``.
        clip_eps: radius of the L-infinity ball around the *clean* input,
            used only when ``do_clip`` is True.
        do_clip: project every iterate into the eps-ball around the clean
            input, intersected with [0, 1] (see ``clip``).
        minimize: descend the loss (targeted attacks) instead of ascending it.

    Returns:
        A detached tensor of adversarial inputs with the shape of ``adv_examples``.
    """
    model.eval()
    direction = -1 if minimize else 1
    # The projection must be centred on the clean input. Centring it on the
    # current iterate would let the perturbation grow by eps on every epoch.
    original = adv_examples.detach()
    # Attack a copy so the caller's tensor is never flagged requires_grad.
    adv = original.clone()
    for _ in range(epochs):
        adv.requires_grad_(True)
        model.zero_grad()
        loss = loss_fct(model(adv), adv_targets)
        loss.backward()
        stepped = adv.detach() + direction * alpha * adv.grad.sign()
        adv = clip(original, stepped, clip_eps) if do_clip else stepped
    return adv
def train_adv_fgsm(
    model: nn.Module, loss_fct: callable, adv_examples: torch.Tensor, adv_targets: torch.Tensor,
    epochs: int = 10, alpha: float = 0.1
):
    """Untargeted sign-gradient attack with no projection (iterated FGSM).

    Delegates to ``train_adv_examples`` with clipping disabled and gradient
    ascent on ``loss_fct``. Because nothing is clipped, each element can move
    by up to ``epochs * alpha`` and may leave the [0, 1] range.
    """
    fgsm_options = {"epochs": epochs, "alpha": alpha, "do_clip": False, "minimize": False}
    return train_adv_examples(model, loss_fct, adv_examples, adv_targets, **fgsm_options)
def train_adv_bim(
    model: nn.Module, loss_fct: callable, adv_examples: torch.Tensor, adv_targets: torch.Tensor,
    epochs: int = 10, alpha: float = 1.0, clip_eps: float = (1 / 255) * 8
):
    """Basic Iterative Method: untargeted sign-gradient ascent with clipping.

    Delegates to ``train_adv_examples`` with ``do_clip=True``, so every
    iterate is projected into an L-infinity ball of radius ``clip_eps``
    intersected with [0, 1]. The default ``alpha`` (1.0) is far larger than
    the default ``clip_eps``, so each step is bounded by that projection.
    """
    return train_adv_examples(
        model,
        loss_fct,
        adv_examples,
        adv_targets,
        **{"epochs": epochs, "alpha": alpha, "do_clip": True, "clip_eps": clip_eps, "minimize": False},
    )
def train_adv_cw(
    model: nn.Module, adv_examples: torch.Tensor, adv_target: int = 3, iteration: int = 5000, lr: float = 0.01, c: float = 1,
    plot: bool = True,
):
    """Targeted Carlini & Wagner L2 attack (tanh change of variables, plain gradient descent).

    The clean inputs are mapped into tanh space, ``x_t = arctanh((x - 0.5) / 0.5 * 0.99999)``.
    A modifier ``w`` (starting at zero) is then optimised so that
    ``x' = tanh(x_t + w) * 0.5 + 0.5`` minimises

        ||x' - x||_2 (over the whole batch) + c * sum_batch max(Z_other - Z_target, -kappa)

    with ``kappa = 0.01``. ``x'`` always stays strictly inside the [0, 1] box.

    Args:
        model: classifier; its output is indexed as ``output[:, adv_target]``, so
            it is expected to be (batch, num_classes) scores — presumably logits,
            TODO confirm. It is switched to eval mode (side effect).
        adv_examples: clean inputs, expected in [0, 1]; not modified.
        adv_target: class index every example is pushed toward.
        iteration: number of gradient-descent steps on the modifier.
        lr: step size of each descent step.
        c: weight of the classification term relative to the L2 distance.
        plot: when True (default, the previous behaviour), plot the distance,
            classification and total loss curves with matplotlib after the attack.

    Returns:
        A detached tensor of adversarial inputs with the shape of ``adv_examples``.
    """
    box_min, box_max = 0, 1
    box_mul = (box_max - box_min) / 2
    box_plus = (box_min + box_max) / 2
    kappa = 0.01  # confidence margin of the classification term

    model.eval()
    x = adv_examples.detach()
    # Optimise in tanh space. The 0.99999 factor keeps arctanh finite at 0 and 1.
    x_tanh = torch.arctanh((x - box_plus) / box_mul * 0.99999)
    modifier = torch.zeros_like(x, requires_grad=True)

    l2dist_list, loss2_list, loss_list = [], [], []
    for _ in range(iteration):
        new_example = torch.tanh(x_tanh + modifier) * box_mul + box_plus
        l2dist = torch.dist(new_example, x, p=2)
        output = model(new_example)
        target_mask = torch.zeros_like(output, dtype=torch.bool)
        target_mask[:, adv_target] = True
        # Largest non-target score. Mask the target with -inf: zero-masking would
        # make the target slot win whenever all other scores are negative.
        others = output.masked_fill(target_mask, float("-inf")).max(dim=1).values
        real = output[:, adv_target]
        loss2 = torch.sum(torch.maximum(torch.full_like(others, -kappa), others - real))
        loss = l2dist + c * loss2
        if plot:
            # Store detached values so each iteration's autograd graph can be freed.
            l2dist_list.append(l2dist.detach())
            loss2_list.append(loss2.detach())
            loss_list.append(loss.detach())
        # autograd.grad only differentiates w.r.t. the modifier, leaving the
        # model's parameter .grad buffers untouched.
        (grad,) = torch.autograd.grad(loss, modifier)
        with torch.no_grad():
            modifier -= lr * grad

    if plot:
        def plot_loss(values, loss_name):
            plt.figure()
            plt.plot(list(range(len(values))), [v.item() for v in values])
            plt.xlabel('iteration times')
            plt.ylabel(loss_name)
            plt.show()

        plot_loss(l2dist_list, 'l2 distance loss')
        plot_loss(loss2_list, 'category loss')
        plot_loss(loss_list, 'all loss')

    with torch.no_grad():
        new_img = torch.tanh(x_tanh + modifier) * box_mul + box_plus
    return new_img
def train_adv_least_likely(
    model: nn.Module, loss_fct: callable, adv_examples: torch.Tensor,
    epochs: int = 10, alpha: float = 0.1, clip_eps: float = (1 / 255) * 8
):
    """Iterative least-likely-class attack.

    Each example is targeted at the class the model currently scores lowest.
    The function then runs clipped sign-gradient *descent* on ``loss_fct``
    toward that target via ``train_adv_examples``.

    Args:
        model: classifier; ``model(x).argmin(dim=1)`` picks the target class,
            so its output is expected to be (batch, num_classes) — TODO confirm.
        loss_fct: callable ``loss_fct(output, targets)`` returning a scalar loss.
        adv_examples: clean inputs.
        epochs: number of sign-gradient steps.
        alpha: step size per step.
        clip_eps: L-infinity radius used by the clipping in ``train_adv_examples``.

    Returns:
        The adversarial tensor produced by ``train_adv_examples``.
    """
    model.eval()
    # Target selection needs no gradients, so skip building an autograd graph.
    with torch.no_grad():
        adv_targets = model(adv_examples).argmin(dim=1)
    return train_adv_examples(
        model, loss_fct, adv_examples, adv_targets,
        epochs=epochs, alpha=alpha, do_clip=True, clip_eps=clip_eps, minimize=True
    )
边栏推荐
- Ultra detailed MP4 format analysis
- Deep understanding of L1 and L2 regularization
- Quickly master QML Chapter 5 components
- 【2023提前批 之 面经】~ 京东方
- [hiflow] regularly send Tencent cloud SMS sending group
- [pyGame actual combat] aircraft shooting masterpiece: fierce battle in the universe is imminent... This super classic shooting game should also be taken out and restarted~
- C语言经典例题-switch case语句转换日期格式
- Application of ERP management system in equipment manufacturing enterprise management
- Google Earth Engine——影像统计过程中出现的空值问题
- [attack and defense world web] difficulty Samsung 9 points introductory question (Part 2): shrink, lottery
猜你喜欢

SCA在得物DevSecOps平台上应用

《快速掌握QML》第五章 组件

C语言经典例题-用4×4矩阵显示从1到16的所有整数,并计算每行、每列和每条对角线上的和

适用于顺序磁盘访问的1分钟法则

Vim到底可以配置得多漂亮?

Backup content hahaha

【攻防世界WEB】难度三星9分入门题(中):ics-05、easytornado

云服务器ECS远程监控

Application of ERP management system in equipment manufacturing enterprise management

Postfix expression (one problem a day during summer vacation, no. 4)
随机推荐
链表合并(暑假每日一题 3)
SCA在得物DevSecOps平台上应用
对专利的学习
后缀表达式(暑假每日一题 4)
Postfix expression (one problem a day during summer vacation, no. 4)
10100
C# 关闭当前电脑指令
Gear 月度更新|6 月
2022最NB的JVM基础到调优笔记,吃透阿里P6小case
[untitled]
C language learning notes
C语言经典例题-switch case语句转换日期格式
SharedPreferences数据储存
Six ways of uniapp route jump
Go: Gin urlencoded format
Exclusive interview | open source Summer Star Niu Xuewei
C#: command to shut down the current computer
UmiJs - qiankun主子应用之间,数据的传递
Gear monthly update June
Vim到底可以配置得多漂亮?