当前位置:网站首页>加载本地cifar10 数据集
加载本地cifar10 数据集
2022-06-25 15:36:00 【全栈程序员站长】
大家好,又见面了,我是你们的朋友全栈君。
由于我们使用官方的导入cifar10数据集方法不成功,在知道cifar10数据集的本地路径的情况下,可以通过以下方法进行导入:
import tensorflow as tf
import numpy as np
import math
from six.moves import cPickle as pickle
import os
import platform
from subprocess import check_output
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
def load_pickle(f):
version = platform.python_version_tuple()
if version[0] == '2':
return pickle.load(f)
elif version[0] == '3':
return pickle.load(f, encoding='latin1')
raise ValueError("invalid python version: {}".format(version))
def load_CIFAR_batch(filename):
""" load single batch of cifar """
with open(filename, 'rb') as f:
datadict = load_pickle(f)
X = datadict['data']
Y = datadict['labels']
X = X.reshape(10000,3072)
Y = np.array(Y)
return X, Y
def load_CIFAR10(ROOT):
""" load all of cifar """
xs = []
ys = []
for b in range(1,6):
f = os.path.join(ROOT, 'data_batch_%d' % (b, ))
X, Y = load_CIFAR_batch(f)
xs.append(X)
ys.append(Y)
Xtr = np.concatenate(xs)
Ytr = np.concatenate(ys)
del X, Y
Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
return Xtr, Ytr, Xte, Yte
def get_CIFAR10_data(num_training=49000, num_validation=1000, num_test=10000):
# Load the raw CIFAR-10 data
cifar10_dir = '../input/cifar-10-batches-py/'
X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)
# Subsample the data
mask = range(num_training, num_training + num_validation)
X_val = X_train[mask]
y_val = y_train[mask]
mask = range(num_training)
X_train = X_train[mask]
y_train = y_train[mask]
mask = range(num_test)
X_test = X_test[mask]
y_test = y_test[mask]
x_train = X_train.astype('float32')
x_test = X_test.astype('float32')
x_train /= 255
x_test /= 255
return x_train, y_train, X_val, y_val, x_test, y_test
# Invoke the above function to get our data.
x_train, y_train, x_val, y_val, x_test, y_test = get_CIFAR10_data()
print('Train data shape: ', x_train.shape)
print('Train labels shape: ', y_train.shape)
print('Validation data shape: ', x_val.shape)
print('Validation labels shape: ', y_val.shape)
print('Test data shape: ', x_test.shape)
print('Test labels shape: ', y_test.shape)
参考:
发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/152110.html原文链接:https://javaforall.cn
边栏推荐
- TFIDF and BM25
- Arthas source code learning-1
- Several solutions to the distributed lock problem in partial Internet companies
- 分享自己平时使用的socket多客户端通信的代码技术点和软件使用
- Jz-065 path in matrix
- 通过客户经理的开户链接开股票账户安全吗?
- 原生js动态添加元素
- JS的遍历和分支判断(2022年6月24日案例)
- MySQL transaction characteristics and implementation principle
- 剑指 Offer 03. 数组中重复的数字
猜你喜欢
[paper notes] street view change detection with deconvolutional networks
Download and installation tutorial of consumer
Globally unique key generation strategy - implementation principle of the sender
Sword finger offer 10- I. Fibonacci sequence
Report on Hezhou air32f103cbt6 development board
Why is it said that restarting can solve 90% of the problems
Kali modify IP address
Sword finger offer 09 Implementing queues with two stacks
AspNetCore&云效Flow持续集成
不要小看了积分商城,它的作用可以很大!
随机推荐
golang reverse a slice
Several relationships of UML
What is the safest app for stock account opening? Tell me what you know
Desktop development (Tauri) opens the first chapter
MySQL modifier l'instruction de champ
程序员 VS 黑客的思维 | 每日趣闻
[paper notes] mcunetv2: memory efficient patch based influence for tiny deep learning
在打新债开户证券安全吗,需要什么准备
CV pre training model set
Go language template text/template error unexpected EOF
Completabilefuture of asynchronous tools for concurrent programming
Continuous integration of aspnetcore & cloud flow
Is it safe to open a stock account in Guoxin golden sun?
[paper notes] street view change detection with deconvolutional networks
golang reverse a slice
Summary of four parameter adjustment methods for machine learning
《睡眠公式》:怎么治睡不好?
Distributed token
李飞飞团队将ViT用在机器人身上,规划推理最高提速512倍,还cue了何恺明的MAE
JS的注释