Java实现深度神经网络:Linear层
Linear.java
/**
 * Fully connected (linear) layer that computes {@code y = W * x + b}.
 *
 * <p>The layer holds {@code out_features} neurons. Each row of {@code weight} is the
 * weight vector of one neuron and {@code bias} holds one offset per neuron. A freshly
 * constructed layer has all weights and biases equal to zero.
 *
 * <p>Arrays passed to the setters are stored by reference and the getters return the
 * stored arrays; no defensive copies are made.
 *
 * @author SXC, 2020-09-11 10:39:35 AM
 */
public class Linear { // holds out_features neurons
    double[][] weight; // one row per neuron
    double[] bias;
    int in_features, out_features;

    /**
     * Creates a layer whose weights and biases are all zero.
     *
     * @param in_features  length of the input vector accepted by {@link #out(double[])}
     * @param out_features number of neurons, i.e. length of the output vector
     */
    public Linear(int in_features, int out_features) {
        weight = new double[out_features][in_features];
        bias = new double[out_features];
        this.in_features = in_features;
        this.out_features = out_features;
    }

    /**
     * Replaces the bias vector.
     *
     * @param bias one value per neuron; must have length {@code out_features}
     * @throws IllegalArgumentException if {@code bias} is {@code null} or has the wrong
     *         length; the previously stored bias is left untouched in that case
     */
    public void setBias(double[] bias) {
        // Validate before assigning so a bad argument can never corrupt the layer.
        if (bias == null || bias.length != out_features) {
            throw new IllegalArgumentException("偏置输入维数不符合要求,应输入维数:" + out_features);
        }
        this.bias = bias;
    }

    /**
     * Replaces the weight matrix.
     *
     * @param weight matrix with {@code out_features} rows of {@code in_features} columns
     * @throws IllegalArgumentException if {@code weight} is {@code null}, has the wrong
     *         number of rows, or any row is {@code null} or has the wrong length; the
     *         previously stored weights are left untouched in that case
     */
    public void setWeight(double[][] weight) {
        // Check the row count first: indexing weight[0] on an empty array would throw.
        if (weight == null || weight.length != out_features) {
            throw new IllegalArgumentException("权重输入行数不符合要求,应输入维数:" + out_features);
        }
        // Every row is checked, not only the first, so jagged matrices are rejected.
        for (double[] row : weight) {
            if (row == null || row.length != in_features) {
                throw new IllegalArgumentException("权重输入列数不符合要求,应输入维数:" + in_features);
            }
        }
        this.weight = weight;
    }

    /**
     * Returns the stored bias vector (not a copy).
     *
     * @return the bias array currently used by the layer
     */
    public double[] getBias() {
        return bias;
    }

    /**
     * Returns the stored weight matrix (not a copy).
     *
     * @return the weight matrix currently used by the layer
     */
    public double[][] getWeight() {
        return weight;
    }

    /**
     * Runs the forward pass: {@code out[i] = sum_j(weight[i][j] * x[j]) + bias[i]}.
     *
     * @param x input vector; must have length {@code in_features}
     * @return a new array of length {@code out_features}
     * @throws IllegalArgumentException if {@code x} is {@code null} or its length differs
     *         from {@code in_features}
     */
    public double[] out(double[] x) {
        if (x == null || x.length != in_features) {
            throw new IllegalArgumentException("输入维数不符合要求,应输入维数:" + in_features);
        }
        double[] out = new double[out_features];
        for (int i = 0; i < out_features; i++) {
            double[] row = weight[i];
            double sum = 0.0;
            for (int j = 0; j < in_features; j++) {
                sum += row[j] * x[j];
            }
            // The bias is added last, after the dot product has been accumulated.
            out[i] = sum + bias[i];
        }
        return out;
    }
}
测试代码
public class t {
public static void main(String[] args) {
Linear linear=new Linear(3, 2);
double[][] weight= {{1,2,3},{3,4,3}};
double[] bias= {9,5};
linear.setBias(bias);
linear.setWeight(weight);
double[] x= {5,7,3};
double[] out= linear.out(x);
System.out.println(out[0]+" " +out[1]);
}
测试结果
37.0 57.0
# Reference computation with torch.nn.Linear using the same weights, bias and
# input as the Java test above.
import torch as t

linear = t.nn.Linear(3, 2)
# Named `x` rather than `input` so the built-in input() is not shadowed.
x = t.tensor([5, 7, 3], dtype=t.float32)
weight = t.tensor([[1, 2, 3], [3, 4, 3]], dtype=t.float32)
bias = t.tensor([9, 5], dtype=t.float32)
# Overwrite the layer's initial parameters with the known values.
linear.weight = t.nn.Parameter(weight)
linear.bias = t.nn.Parameter(bias)
print(linear(x))
运行结果
H:\ProgramData\Anaconda3\python.exe D:/PycharmProjects/untitled/t.py
tensor([37., 57.], grad_fn=<AddBackward0>)
Process finished with exit code 0
Java 实现的输出(37.0 57.0)与 PyTorch 中 nn.Linear 的输出(tensor([37., 57.]))相同,达到目标
因篇幅问题不能全部显示,请点此查看更多更全内容