博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Pytorch设立计算图并自动计算
阅读量:5954 次
发布时间:2019-06-19

本文共 1872 字,大约阅读时间需要 6 分钟。

本博文参考七月在线pytorch课程

1.numpy和pytorch实现梯度下降法

使用numpy实现简单神经网络

import numpy as npN, D_in, H, D_out = 64, 1000, 100, 10# 随机创建一些训练数据x = np.random.randn(N, D_in)y = np.random.randn(N, D_out)w1 = np.random.randn(D_in, H)w2 = np.random.randn(H, D_out)learning_rate = 1e-6for it in range(500):    # Forward pass    h = x.dot(w1) # N * H    h_relu = np.maximum(h, 0) # N * H    y_pred = h_relu.dot(w2) # N * D_out        # compute loss    loss = np.square(y_pred - y).sum()    print(it, loss)        # Backward pass    # compute the gradient    grad_y_pred = 2.0 * (y_pred - y)    grad_w2 = h_relu.T.dot(grad_y_pred)    grad_h_relu = grad_y_pred.dot(w2.T)    grad_h = grad_h_relu.copy()    grad_h[h<0] = 0    grad_w1 = x.T.dot(grad_h)        # update weights of w1 and w2    w1 -= learning_rate * grad_w1    w2 -= learning_rate * grad_w2

使用pytorch实现简单神经网络

N, D_in, H, D_out = 64, 1000, 100, 10# 随机创建一些训练数据x = torch.randn(N, D_in)y = torch.randn(N, D_out)w1 = torch.randn(D_in, H)w2 = torch.randn(H, D_out)learning_rate = 1e-6for it in range(500):    # Forward pass    h = x.mm(w1) # N * H    h_relu = h.clamp(min=0) # N * H    y_pred = h_relu.mm(w2) # N * D_out        # compute loss    loss = (y_pred - y).pow(2).sum().item()    print(it, loss)        # Backward pass    # compute the gradient    grad_y_pred = 2.0 * (y_pred - y)    grad_w2 = h_relu.t().mm(grad_y_pred)    grad_h_relu = grad_y_pred.mm(w2.t())    grad_h = grad_h_relu.clone()    grad_h[h<0] = 0    grad_w1 = x.t().mm(grad_h)        # update weights of w1 and w2    w1 -= learning_rate * grad_w1    w2 -= learning_rate * grad_w2

设定初始值

#numpyx = np.random.randn(N, D_in)y = np.random.randn(N, D_out)w1 = np.random.randn(D_in, H)w2 = np.random.randn(H, D_out)#pytorchx = torch.randn(N, D_in)y = torch.randn(N, D_out)w1 = torch.randn(D_in, H)w2 = torch.randn(H, D_out)

转载于:https://www.cnblogs.com/lky520hs/p/10864952.html

你可能感兴趣的文章
数论之 莫比乌斯函数
查看>>
linux下查找某个文件位置的方法
查看>>
python之MySQL学习——数据操作
查看>>
Harmonic Number (II)
查看>>
长连接、短连接、长轮询和WebSocket
查看>>
day30 模拟ssh远程执行命令
查看>>
做错的题目——给Array附加属性
查看>>
Url.Action取消字符转义
查看>>
JQuery选择器大全
查看>>
Gamma阶段第三次scrum meeting
查看>>
python3之装饰器修复技术@wraps
查看>>
[考试]20150606
查看>>
Javascript_备忘录5
查看>>
Can’t create handler inside thread that has not called Looper.prepare()
查看>>
敏捷开发方法综述
查看>>
Hadoop数据操作系统YARN全解析
查看>>
修改数据库的兼容级别
查看>>
Windows下同时安装两个版本Jdk
查看>>
uoj#228. 基础数据结构练习题(线段树)
查看>>
JS键盘事件监听
查看>>