【Pytorch】nn.Module
2022-07-24 07:34:00 【rejudge】
import torch
from torch import nn
device="cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
Using cuda device
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
model=NeuralNetwork().to(device)
model
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)
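As a rough cross-check of the printed architecture, the number of trainable parameters can be worked out by hand from the layer sizes. The sketch below is plain Python and does not touch the model itself; the helper name `linear_param_count` is made up for illustration. It assumes each `nn.Linear` stores an `out_features x in_features` weight matrix plus one bias per output, as the `bias=True` entries in the printout suggest.

```python
# Layer sizes copied from the nn.Linear calls above: (in_features, out_features).
layer_sizes = [(28 * 28, 512), (512, 512), (512, 10)]


def linear_param_count(in_features, out_features, bias=True):
    """Hypothetical helper: weights (out x in) plus one bias per output unit."""
    return in_features * out_features + (out_features if bias else 0)


per_layer = [linear_param_count(i, o) for i, o in layer_sizes]
total_params = sum(per_layer)

print(per_layer)     # [401920, 262656, 5130]
print(total_params)  # 669706
```

With PyTorch available, the same total should be obtainable from the model via `sum(p.numel() for p in model.parameters())`.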
input_image=torch.rand(3,28,28)
input_image=input_image.to(device)
input_image.size()
torch.Size([3, 28, 28])
logits=model(input_image)
logits.shape
torch.Size([3, 10])
softmax=nn.Softmax(dim=1)
softmax(logits)
tensor([[0.1017, 0.0993, 0.1022, 0.0997, 0.1048, 0.1015, 0.0971, 0.0979, 0.0961,
         0.0996],
        [0.1023, 0.1030, 0.0998, 0.1008, 0.1009, 0.1034, 0.1012, 0.0952, 0.0925,
         0.1010],
        [0.1048, 0.1021, 0.1008, 0.1018, 0.1034, 0.1022, 0.0991, 0.0925, 0.0965,
         0.0970]], device='cuda:0', grad_fn=<SoftmaxBackward>)
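Each row of the output above sums to 1, because `nn.Softmax(dim=1)` normalizes across the class dimension for every sample. A minimal plain-Python sketch of that per-row computation follows. The function `softmax` here is a hand-written stand-in, not the PyTorch module, and the logits are invented example values rather than the ones from this post.

```python
import math


def softmax(row):
    """Numerically stable softmax over one row of logits."""
    m = max(row)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits: a batch of 2 samples with 3 classes each.
example_logits = [[2.0, 1.0, 0.1], [0.0, 0.0, 0.0]]

# Apply softmax row by row, which is what dim=1 means for a 2-D batch.
probs = [softmax(row) for row in example_logits]

# Predicted class = index of the largest probability in each row.
predicted = [max(range(len(p)), key=p.__getitem__) for p in probs]
```

On the real tensor, the predicted class per sample would typically be read off with `softmax(logits).argmax(1)`. Since the model is untrained, the probabilities are all close to 0.1 and the prediction carries no meaning yet.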