Current location: Home > Record the training process

Record the training process

2022-06-25 20:06:00 Orange cedar

# Mean training loss of every epoch, accumulated (across calls) by train_net.
train_curve = []
def _save_debug_image(image, label, pred, path):
    """Save the first sample of a batch as an ``image | label | prediction`` strip.

    Args:
        image: input batch tensor. ``image[0]`` is permuted with ``(1, 2, 0)``,
            so a (N, C, H, W) layout is assumed; pixel values are written
            unscaled, so they are presumably already on a 0-255 scale
            (NOTE(review): confirm against ISBI_Loader).
        label: ground-truth batch with the same layout; scaled by 255, so
            presumably in [0, 1].
        pred: raw network output (logits) with the same layout; passed through
            a sigmoid and scaled by 255.
        path: destination image path; its parent directory is created if needed.
    """
    import os

    def _to_uint8(batch, scale):
        # First sample, CHW -> HWC, scaled and clipped to the 8-bit range.
        # ``np.int`` no longer exists in current NumPy, and cv2.imwrite expects
        # 8-bit data for ordinary images, so convert to uint8 explicitly.
        arr = batch[0].detach().permute(1, 2, 0).cpu().numpy() * scale
        return np.clip(arr, 0, 255).astype(np.uint8)

    strip = np.concatenate(
        [
            _to_uint8(image, 1),
            _to_uint8(label, 255),
            _to_uint8(torch.sigmoid(pred.detach()), 255),
        ],
        axis=1,  # side by side along the width
    )
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    cv2.imwrite(path, strip)


def _plot_loss_curve(losses, fname):
    """Plot per-epoch training losses, save the figure to ``fname``, then show it."""
    plt.figure(figsize=(10, 10))  # figure size is set on the figure, not in savefig()
    plt.plot(range(len(losses)), losses, label='Train')
    plt.legend(loc='upper right')
    plt.ylabel('loss value')
    plt.xlabel('Epoch')
    # Save before show(): after a blocking show() returns the figure is
    # usually closed, and a later savefig() would write an empty image.
    plt.savefig(fname)
    plt.show()


def train_net(net, device, data_path, epochs=100, batch_size=4, lr=0.01, *,
              log_dir='logs/liver1', ckpt_dir='./Pth/liver1',
              loss_plot_path='liver1_loss.png',
              image_every=8000, ckpt_every=20):
    """Train ``net`` on the dataset at ``data_path`` with RMSprop and BCE-with-logits.

    Each epoch's mean batch loss is printed and appended both to the
    module-level ``train_curve`` list and to the returned list. The curve for
    this run is plotted, saved and shown at the end.

    Args:
        net: the model to train (a ``torch.nn.Module``).
        device: device that each batch is moved to.
        data_path: path passed to ``ISBI_Loader``.
        epochs: number of passes over the training set.
        batch_size: DataLoader batch size.
        lr: RMSprop learning rate.
        log_dir: directory for periodic ``image | label | prediction`` dumps.
        ckpt_dir: directory for ``Liver_U_Net_<epoch>.pth`` state-dict checkpoints.
        loss_plot_path: file the loss-curve figure is saved to.
        image_every: dump a debug image every this many batches (0 disables).
        ckpt_every: save a checkpoint every this many epochs (0 disables).

    Returns:
        list of float: the mean training loss of each epoch of this run.
    """
    import os

    isbi_dataset = ISBI_Loader(data_path)
    train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    criterion = nn.BCEWithLogitsLoss()

    epoch_losses = []
    step = 0  # batch counter across all epochs; names the debug images
    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch + 1, epochs))
        print('-' * 10)
        net.train()
        epoch_loss = 0.0
        num_batches = 0

        for image, label in train_loader:
            step += 1
            num_batches += 1
            image = image.to(device=device, dtype=torch.float32)
            label = label.to(device=device, dtype=torch.float32)

            optimizer.zero_grad()
            pred = net(image)
            loss = criterion(pred, label)

            # Periodically dump input / ground truth / prediction for inspection.
            if image_every and step % image_every == 0:
                _save_debug_image(image, label, pred,
                                  os.path.join(log_dir, str(step) + '.png'))

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        # Average over this epoch's batches only. The cumulative ``step`` spans
        # all epochs, so dividing by it would shrink the loss every epoch.
        mean_loss = epoch_loss / max(num_batches, 1)
        print('Epoch {}/{}'.format(epoch + 1, epochs), '   Loss/train:%0.3f' % mean_loss)
        epoch_losses.append(mean_loss)
        train_curve.append(mean_loss)

        if ckpt_every and (epoch + 1) % ckpt_every == 0:
            os.makedirs(ckpt_dir, exist_ok=True)
            torch.save(net.state_dict(),
                       os.path.join(ckpt_dir, 'Liver_U_Net_%d.pth' % (epoch + 1)))

    _plot_loss_curve(epoch_losses, loss_plot_path)
    return epoch_losses

Original site

Copyright notice
This article was written by [Orange cedar]. If you repost it, please include a link to the original article. Thank you.
https://yzsam.com/2022/02/202202190507547544.html