Java基本数据类型 GraphQL 双重检验锁 Kotlin elasticsearch layout devise laravel4 jScroll jquery去除空格 css获取最后一个元素 nginx默认端口号 vm虚拟化引擎 mysql或者条件 mysql临时表 河南普通话报名入口 python简易教程 python基础教程免费 random函数用法 java开发 java配置jdk linux系统教程 网页游戏开发入门 嵌入式linux驱动程序设计从入门到精通 行业软件下载 微信超级好友 电子书制作软件 pdf安装包官方下载 小米手环充电多久 jdk9下载 dnf95b套 免费图片文字识别软件 透视网格工具怎么取消 深入解析windows操作系统 oemdiy 动漫情侣头像一男一女 华为工具箱 微信昵称特殊符号 total同级生2下载 电脑微信官方下载
当前位置: 首页 > 学习教程  > 编程语言

tfRecord TypeError: only integer scalar arrays can be converted to a scalar index 错误解决办法

2021/2/13 17:09:28 文章标签: 测试文章如有侵权请发送至邮箱809451989@qq.com投诉后文章立即删除

从网上找到一个通过numpy生成tfrecord的代码,但是运行时报错,出现TypeError: only integer scalar arrays can be converted to a scalar index错误,原因是该记录为类型不匹配 需要从integer scalar arrays -> 单个int64数字原有问题代码…

 从网上找到一个通过numpy生成tfrecord的代码,但是运行时报错,出现TypeError: only integer scalar arrays can be converted to a scalar index错误,原因是该记录为类型不匹配 需要从integer scalar arrays  -> 单个int64数字

原有问题代码如下,注释部分为正确代码
"""
本程序演示了如何保存numpy array为TFRecords文件,并将其读取出来。
"""
import random

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

def save_tfrecords(state_data, action_data, reward_data, dest_file):
    """
    保存numpy array到TFRecord文件中。
    这里输入了三个不同的numpy array来做演示,它们含有不同类型的元素。
    Args:
        state_data: 要保存到TFRecord文件的第1个numpy array,每一个 state_data[i] 是一个 numpy.ndarray(数组里的每个元素又是一个浮点
                    数),因此不能用 Int64List 或 FloatList 来存储,只能用 BytesList。
        action_data: 要保存到TFRecord文件的第2个numpy array,每一个 action_data[i] 是一个整数,使用 Int64List 来存储。
        reward_data: 要保存到TFRecord文件的第3个numpy array,每一个 reward_data[i] 是一个整数,使用 Int64List 来存储。
        dest_file: 输出文件的路径。
    Returns:
        不返回任何值
    """
    with tf.io.TFRecordWriter(dest_file) as writer:
        for i in range(len(state_data)):
            features = tf.train.Features(
                feature={
                    "state": tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[state_data[i].astype(np.float32).tobytes()])),
                    "action": tf.train.Feature(
                        int64_list=tf.train.Int64List(value=[action_data[i]])),
                    "reward": tf.train.Feature(
                        int64_list=tf.train.Int64List(value=[reward_data[i]]))
                    # "action": tf.train.Feature(
                    #     int64_list=tf.train.Int64List(value=action_data[i].astype(np.int))),
                    # "reward": tf.train.Feature(
                    #     int64_list=tf.train.Int64List(value=reward_data[i].astype(np.int)))
                }
            )
            tf_example = tf.train.Example(features=features)
            serialized = tf_example.SerializeToString()
            writer.write(serialized)



if __name__ == '__main__':
    buffer_s, buffer_a, buffer_r = [], [], []

    # 随机生成一些数据
    for i in range(3):
        state = [round(random.random() * 100, 2) for _ in range(0, 10)]  # 一个数组,里面有10个数,每个都是一个浮点数
        action = random.randrange(0, 2)  # 一个数,值为 0 或 1
        reward = random.randrange(0, 100)  # 一个数,值域 [0, 100)
        # 把生成的数分别添加到3个list中
        buffer_s.append(state)
        buffer_a.append(action)
        buffer_r.append(reward)

        # 查看生成的数据
    print(buffer_s)
    print(buffer_a)
    print(buffer_r)

    # 在水平方向把各个list堆叠起来,堆叠的结果:得到3个矩阵
    s_stacked = np.vstack(buffer_s)
    a_stacked = np.vstack(buffer_a)
    r_stacked = np.vstack(buffer_r)

    print(s_stacked.shape)  # (3, 10)
    print(a_stacked.shape)  # (3, 1)
    print(r_stacked.shape)  # (3, 1)


    print(s_stacked)
    print(a_stacked)
    print(r_stacked)


    print("data generate sucess!")

    # 写入TFRecord文件
    output_file = './data.tfrecord'  # 输出文件的路径
    save_tfrecords(s_stacked, a_stacked, r_stacked, output_file)

原始代码块:
action:为
[[0]
[1]
[1]]
通过切片操作获取其中的一个元素,这个元素也是list,所以不需要再强制转为list,通过 .astype(np.int) 转为具体类型


本文链接: http://www.dtmao.cc/news_show_700069.shtml

附件下载

相关教程

    暂无相关的数据...

共有条评论 网友评论

验证码: 看不清楚?