程序问答   发布时间:2022-06-01  发布网站:大佬教程  code.js-code.com
大佬教程收集整理的这篇文章主要介绍了使用 tfdataset 训练 CNN大佬教程大佬觉得挺不错的,现在分享给大家,也给大家做个参考。

如何解决使用 tfdataset 训练 CNN?

开发过程中遇到使用 tfdataset 训练 CNN的问题如何解决?下面主要结合日常开发的经验,给出你关于使用 tfdataset 训练 CNN的解决方法建议,希望对你解决使用 tfdataset 训练 CNN有所启发或帮助;

我正在尝试使用 TFRecordDataset 训练 CNN(我认为这是无关的,但这是我的情况)并收到以下错误:

ValueError: 维度 0 的切片索引 0 越界。对于'{{节点 strIDed_slice}} = StrIDedSlice[索引=DT_INT32,T=DT_INT32, begin_mask=0,ellipsis_mask=0,end_mask=0,new_axis_mask=0,shrink_axis_mask=1](形状,strIDed_slice/stack,strIDed_slice/stack_1, strIDed_slice/stack_2)' 输入形状:[0],[1],[1] 和 计算输入张量:input[1] =,input[2] =,input[3] = .

举个例子,这是我正在执行的代码:

美国有线电视新闻网:

import tensorflow as tf
def get_cnn_model(input_shape=(31,31,9),n_outputs=4,convolutions=3,optimizer='adam',seed=26):
    tf.random.set_seed(seed=seed)
    _input = layers.input(shape=input_shape,name='input')
    x = layers.Conv2D(64,(4,4),activation='relu',padding='same',name=f'conv_0')(_input)
    x = layers.MaxPooling2D(2)(x)
    for i in range(convolutions - 1):
        x = layers.Conv2D(64,name=f'conv_{i + 1}')(x)
        x = layers.MaxPooling2D(2)(x)
    x = layers.Flatten()(x)
    x = layers.Dense(128,name='dense_1')(x)
    x = layers.Dropout(0.35,name='dropout_1')(x)
    x = layers.Dense(128,name='dense_2')(x)
    x = layers.Dropout(0.35,name='dropout_2')(x)
    p = layers.Dense(n_outputs,activation='tanh',name='p')(x)
    v = layers.Dense(1,name='v')(x)
    cnn_model = Model(inputs=_input,outputs=[v,p])
    losses = {
        "v": 'mean_squared_error',"p": keras.losses.BinaryCrossentropy()
    }
    cnn_model.compile(loss=losses,optimizer=optimizer)
    return cnn_model

cnn = get_cnn_model((31,n_outputs=16,seed=26)

这是示例数据集:

import numpy as np
import tensorflow as tf

v = 0.9
p = np.random.randn(16)
state = np.random.randn(31*31*9)

sample = tf.train.Example(
    features = tf.train.Features(
        feature = {
            'v': tf.train.Feature(float_List=tf.train.floatList(value=[v])),'p': tf.train.Feature(float_List=tf.train.floatList(value = p)),'s': tf.train.Feature(float_List=tf.train.floatList(value = state))
        }
    )
)

with tf.io.TFRecorDWriter('tf_record_data') as f:
    f.write(sample.SerializetoString())

这是我得到上述错误的训练过程:

def read_tfrecord(example):
    feature_desc = {
        'v': tf.io.FixedLenFeature([],tf.float32),'p': tf.io.VarLenFeature(tf.float32),'s': tf.io.VarLenFeature(tf.float32)
    }
    sample = tf.io.parse_single_example(example,feature_desc)
    x = tf.reshape(tf.sparse.to_dense(parsed['s']),(1,9))
    y = {'v':sample['v'],'p': tf.sparse.to_dense(sample['p'])}
    return x,y

ds = tf.data.TFRecordDataset(['tf_record_data'])
ds = ds.map(read_tfrecord)

cnn.fit(ds)

有趣的是,当我对数据集进行预测时,它确实有效:

import numpy as np
for serialized in tf.data.TFRecordDataset(['tf_record_data']):
    parsed = tf.io.parse_single_example(serialized,feature_desc)
    st= tf.sparse.to_dense(parsed['s'])
    t = tf.reshape(st,9))
    print(cnn.predict(t))

我该如何解决这个错误?

解决方法

我将数据记录的映射更改为以下内容:

def read_tfrecord(example):
    feature_desc = {
       'v': tf.io.FixedLenFeature([],tf.float32),'p': tf.io.VarLenFeature(tf.float32),'s': tf.io.VarLenFeature(tf.float32)
    }
    sample = tf.io.parse_single_example(example,feature_desc)
    x = tf.reshape(tf.sparse.to_dense(parsed['s']),(1,rows,cols,layers))
    p = tf.reshape(tf.sparse.to_dense(parsed['p']),16))
    v = tf.reshape(sample['v'],1))

    y = {'v':v,'p': p}
    return x,y

重塑输出解决了问题

大佬总结

以上是大佬教程为你收集整理的使用 tfdataset 训练 CNN全部内容,希望文章能够帮你解决使用 tfdataset 训练 CNN所遇到的程序开发问题。

如果觉得大佬教程网站内容还不错,欢迎将大佬教程推荐给程序员好友。

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
如您有任何意见或建议可联系处理。小编QQ:384754419,请注明来意。
标签: