纯净、安全、绿色的下载网站

首页|软件分类|下载排行|最新软件|IT学院

当前位置:首页IT学院IT技术

Python pyTorch权重与L2范数正则化 Python深度学习pyTorch权重衰减与L2范数正则化解析

算法菜鸟飞高高   2021-09-30 我要评论
想了解Python深度学习pyTorch权重衰减与L2范数正则化解析的相关内容吗,算法菜鸟飞高高在本文为您仔细讲解Python pyTorch权重与L2范数正则化的相关知识和一些Code实例,欢迎阅读和指正,我们先划重点:Python深度学习,Python,pyTorch权重与L2范数正则化,下面大家一起来学习吧。

在这里插入图片描述

下面进行一个高维线性实验

假设我们的真实方程是:

在这里插入图片描述

假设feature数200,训练样本和测试样本各20个

模拟数据集

num_train,num_test = 10,10
num_features = 200
true_w = torch.ones((num_features,1),dtype=torch.float32) * 0.01
true_b = torch.tensor(0.5)
samples = torch.normal(0,1,(num_train+num_test,num_features))
noise = torch.normal(0,0.01,(num_train+num_test,1))
labels = samples.matmul(true_w) + true_b + noise
train_samples, train_labels= samples[:num_train],labels[:num_train]
test_samples, test_labels = samples[num_train:],labels[num_train:]

定义带正则项的loss function

def loss_function(predict,label,w,lambd):
    loss = (predict - label) ** 2
    loss = loss.mean() + lambd * (w**2).mean()
    return loss

画图的方法

def semilogy(x_val,y_val,x_label,y_label,x2_val,y2_val,legend):
    plt.figure(figsize=(3,3))
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.semilogy(x_val,y_val)
    if x2_val and y2_val:
        plt.semilogy(x2_val,y2_val)
        plt.legend(legend)
    plt.show()

拟合和画图

def fit_and_plot(train_samples,train_labels,test_samples,test_labels,num_epoch,lambd):
    w = torch.normal(0,1,(train_samples.shape[-1],1),requires_grad=True)
    b = torch.tensor(0.,requires_grad=True)
    optimizer = torch.optim.Adam([w,b],lr=0.05)
    train_loss = []
    test_loss = []
    for epoch in range(num_epoch):
        predict = train_samples.matmul(w) + b
        epoch_train_loss = loss_function(predict,train_labels,w,lambd)
        optimizer.zero_grad()
        epoch_train_loss.backward()
        optimizer.step()
        test_predict = test_sapmles.matmul(w) + b
        epoch_test_loss = loss_function(test_predict,test_labels,w,lambd)
        train_loss.append(epoch_train_loss.item())
        test_loss.append(epoch_test_loss.item())
    semilogy(range(1,num_epoch+1),train_loss,'epoch','loss',range(1,num_epoch+1),test_loss,['train','test'])

在这里插入图片描述
可以发现加了正则项的模型,在测试集上的loss确实下降了

以上就是Python深度学习pyTorch权重衰减与L2范数正则化解析的详细内容,更多关于Python pyTorch权重与L2范数正则化的资料请关注其它相关文章!


相关文章

猜您喜欢

  • Java 多线程 Java多线程学习笔记

    想了解Java多线程学习笔记的相关内容吗,四季人06在本文为您仔细讲解Java 多线程的相关知识和一些Code实例,欢迎阅读和指正,我们先划重点:Java,多线程,Java,Thread,Java,Runable,下面大家一起来学习吧。..
  • Spring Boot应用Docker部署 Spring Boot应用通过Docker发布部署的流程分析

    想了解Spring Boot应用通过Docker发布部署的流程分析的相关内容吗,imonkeyi在本文为您仔细讲解Spring Boot应用Docker部署的相关知识和一些Code实例,欢迎阅读和指正,我们先划重点:Spring,Boot应用Docker部署,Spring,Boot应用Docker部署,下面大家一起来学习吧。..

网友评论

Copyright 2020 www.zhuchaoyouxi.com 【筑巢游戏】 版权所有 软件发布

声明:所有软件和文章来自软件开发商或者作者 如有异议 请与本站联系 点此查看联系方式