pytorch拟合函数

pytorch拟合函数pytorch拟合一元一次函数1.这个程序定义网络的方式很特别2.使用卷积拟合  拟合函数y=a×x+by=a\timesx+by=a×x+b,其中a=1,b=2a=1,b=2a=1,b=2。1.这个程序定义网络的方式很特别importtorchimportnumpyasnpclassNet:def__init__(self):self….

欢迎访问我的博客首页


1. 拟合一元一次函数


  拟合函数 y = a × x + b y=a\times x+b y=a×x+b,其中 a = 1 , b = 2 a=1,b=2 a=1,b=2

1.1 自定义网络


import torch
import numpy as np

class Net:
    def __init__(self):
        self.a = torch.rand(1, requires_grad=True)
        self.b = torch.rand(1, requires_grad=True)
        self.__parameters = dict(a=self.a, b=self.b)
        self.___gpu = False
    def forward(self, inputs):
        return self.a * inputs + self.b
    def parameters(self):
        for name, param in self.__parameters.items():
            yield param

if __name__ == '__main__':
    x = np.linspace(1, 50, 50)
    y = x + 2
    x = torch.from_numpy(x.astype(np.float32))
    y = torch.from_numpy(y.astype(np.float32))
    net = Net()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=0.0005)
    loss_op = torch.nn.MSELoss(reduction='sum')
    for i in range(1, 20001, 1):
        out = net.forward(x)
        loss = loss_op(y, out)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # 输出中间过程
        loss_numpy = loss.cpu().detach().numpy()
        if i % 1000 == 0:
            print(i, loss_numpy)
        if loss_numpy < 0.00001:
            a = net.a.cpu().detach().numpy()
            b = net.b.cpu().detach().numpy()
            print(a, b)
            exit()

  这种方法定义网络时没有继承torch.nn.Module,完全自己写了一个网络,要显式调用Net的forward函数。损失函数使用的是L2损失。

1.2 使用卷积网络


import torch
import numpy as np

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        layers = []
        layers.append(torch.nn.Conv2d(1, 1, kernel_size=1, stride=1, bias=True))
        self.net = torch.nn.ModuleList(layers)
    def forward(self, x):
        return self.net[0](x)

if __name__ == '__main__':
    x = np.linspace(1, 50, 50)
    y = x + 2  # a = 1, b = 2
    x = torch.from_numpy(x.astype(np.float32))
    y = torch.from_numpy(y.astype(np.float32))
    net = Net()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=0.0005)
    loss_op = torch.nn.L1Loss(reduce=True, size_average=True)
    for i in range(20000):
        x_batch = torch.tensor([x[i % 50]]).reshape(1, 1, 1, 1)
        y_batch = torch.tensor([y[i % 50]]).reshape(1, 1, 1, 1)
        out = net(x_batch)
        loss = loss_op(y_batch, out)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # 输出中间过程
        loss_numpy = loss.cpu().detach().numpy()
        if i % 1000 == 0:
            print('--iterator:', i, 'loss:', loss_numpy)
        if loss_numpy < 1e-10:
            break
    for k, v in net.named_parameters():
        print(k, v.cpu().detach().numpy())

  这种方法使用卷积网络拟合。使用L2损失拟合的效果很差,这里使用L1作为损失函数。原因是L2度量的是误差的平方,当误差小于1时,L2度量的误差数量级比实际误差的数量级成倍减少。本例 b a t c h _ s i z e = 1 batch\_size=1 batch_size=1,迭代若干次后误差必定小于1,所以使用L1损失在loss达到指定阈值时收敛得更好。

2. 拟合一个数


  随机产生一个0到1的数,计算它与0.5的差的平方,通过梯度把产生的数调到0.5。

2.1 第一种方法


import torch

if __name__ == '__main__':
    x = torch.rand(1, requires_grad=True)
    a = torch.tensor([0.5])
    optimizer = torch.optim.SGD([x], lr=0.05, weight_decay=0.00003)
    loss_fn = torch.nn.MSELoss()
    while True:
        loss = loss_fn(x, a)
        if loss < 1e-10:
            break
        optimizer.zero_grad()
        loss.backward()
        print('-- ', x.tolist(), x.grad.tolist())
        optimizer.step()

  第6行指定优化参数x可以使用任何迭代器,比如元组 ( )、列表 [ ]、集合 { }。还可以使用生成器,如1.1中第12-14行。

2.2 第二种方法


import torch

if __name__ == '__main__':
    x = torch.rand(1, requires_grad=True)
    a = torch.tensor([0.5])
    loss_fn = torch.nn.MSELoss()
    while True:
        loss = loss_fn(x, a)
        if loss < 1e-10:
            break
        loss.backward()
        print('-- ', x.tolist(), x.grad.tolist())
        x.data = x - 0.6 * x.grad
        x.grad.data.zero_()

  这种方法没有使用优化器,直接利用梯度修改x的值。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
如需转载请保留出处:https://bianchenghao.cn/34255.html

(0)
编程小号编程小号

相关推荐

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注