《PyTorch深度学习实践》- 刘二大人p7 - 处理多维特征的输入

处理多维特征的输入(Multiple_Dimension_Input)

《PyTorch深度学习实践》- 刘二大人p7

多维特征的输入

xSPi2F.png

这是一份用各项指标来检测是否患有糖尿病的数据库,每行x1~x8共有8个输入,可看作身体8个指标,输出Y=0,则不患糖尿病;y=1,则患有糖尿病。

## 逻辑斯蒂回归模型更新

既然从一维输入变为八维输入,则Loss模型也将相应的变换。

xSPXRO.png

每个输入xn都要乘上权重wn,最后再加上一个偏置量b。

公式详解:

xSiZQg.png

 

Mini-Batch

xSiHXQ.png

降维

代码中这一行的意思是指输入8维,输出1维。

1
self.linear = torch.nn.Linear(8, 1) 

但在实际中将使用降维,如下,从8维降到6维,再降到4维,最后再降到1维输出,可以节约内存。xSFfu4.png

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import numpy as np
from matplotlib import pyplot as plt

xy = np.loadtxt('diabetes.csv.gz', delimiter=',', dtype=np.float32)
x_data = torch.from_numpy(xy[:, :-1]) # 取数据表格前八列
y_data = torch.from_numpy(xy[:, [-1]]) # 取数据表格最后一列


class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
# self.linear = torch.nn.Linear(8, 1) # 指输入8维,输出1维
self.linear1 = torch.nn.Linear(8, 6)
self.linear2 = torch.nn.Linear(6, 4)
self.linear3 = torch.nn.Linear(4, 1)
self.sigmoid = torch.nn.Sigmoid()

def forward(self, x):
x = self.sigmoid(self.linear1(x))
x = self.sigmoid(self.linear2(x))
x = self.sigmoid(self.linear3(x))
return x

model = Model()

criterion = torch.nn.BCELoss(reduction='mean') # 损失loss求平均
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)

epoch_list = []
loss_list = []
for epoch in range(1000):
# forward
y_pred = model(x_data)
loss = criterion(y_pred, y_data)
print('epoch = ',epoch,'loss = ', loss.item())

# backward
optimizer.zero_grad()
loss.backward()

# update
optimizer.step()

epoch_list.append(epoch)
# loss_list.append(loss)
loss_list.append(loss.item())

plt.plot(epoch_list,loss_list)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.show()

结果图

xSA4t1.png

更换激活函数

xSAqne.png

之前使用的的是Logistic(sigmoid) Function,现选择几个激活函数

ReLu

代码更换处

xSuLx1.png

但考虑到ReLu的特性,最后x小于0时,还使用ReLu的话y-hat将变为0,loss中的logy-hat变为log0,所以最后一行使用sigmod,使结果图曲线更加平滑

结果图

xSKQRs.png

文章作者: CasimiBreidin
文章链接: https://blognotes.cn/posts/41853.html
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 Casimi’Blog