JavaWeb less 动态条形图 mAPI uitableview meteor methods nhibernate vbscript compilation Vanilla JS change事件 jquery解析json数据 大数据驾驶舱 录音棚设备一套多少钱 mysql合并结果集 mysql汉化包 flutter优缺点 kubernetes架构 java表达式 java中的队列 java集合转数组 java正则匹配数字 java的集合 java定义 网站后台模板 脚本之家 xp画图工具 千元以下最好的手机 倒计时计时器 烧饼修改器打不开 cubase下载 cad视口旋转 rpm卸载命令 cdr怎么填充颜色 铁血统帅 ps给图片加边框 街机roms下载 python保存文件 人马上单天赋
当前位置: 首页 > 学习教程  > 编程语言

tensorflow2.0在训练数据集的时候,fit和fit_generator的使用

2020/11/24 10:51:51 文章标签: 测试文章如有侵权请发送至邮箱809451989@qq.com投诉后文章立即删除

model.fit函数 fit(xNone, yNone, batch_sizeNone, epochs1, verbose1, callbacksNone,validation_split0.0, validation_dataNone, shuffleTrue, class_weightNone,sample_weightNone, initial_epoch0, steps_per_epochNone,validation_stepsNone, validation_batch_sizeNone…

model.fit函数

fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None,
    validation_split=0.0, validation_data=None, shuffle=True, class_weight=None,
    sample_weight=None, initial_epoch=0, steps_per_epoch=None,
    validation_steps=None, validation_batch_size=None, validation_freq=1,
    max_queue_size=10, workers=1, use_multiprocessing=False
)

为模型训练固定的批次(数据集上的迭代)。

参数:

参数作用
x输入数据.它可能是:Numpy数组(或类似数组的数组)或数组列表(如果模型具有多个输入)。TensorFlow张量或张量列表(如果模型具有多个输入)。如果模型已命名输入,则dict将输入名称映射到相应的数组/张量。一个tf.data数据集。应该返回(inputs, targets)或 的元组(inputs, targets, sample_weights)。生成器或keras.utils.Sequence返回(inputs, targets) 或(inputs, targets, sample_weights)。下面给出了迭代器类型(数据集,生成器,序列)的拆包行为的更详细描述。
y目标数据。像输入数据一样x,它可以是Numpy数组或TensorFlow张量。它应该与x(您不能有Numpy输入和张量目标,或者相反)保持一致。如果x是keras.utils.Sequence,y则不应该指定,生成器或实例(因为将从中获取目标x)。
batch_size整数或None。每个梯度更新的样本数。如果未指定,则默认为32。如果数据是以数据集,生成器或实例的形式(因为它们生成批次),则不要指定。 batch_sizebatch_sizekeras.utils.Sequence
epochs整数。训练模型的时期数。时期是整个x和所y 提供数据的迭代。请注意,在与结合, 应理解为“最后时期”。不会针对给出的多次迭代训练模型,而只是对到达索引的时期进行训练。 initial_epochepochsepochsepochs
verbose0、1或2。详细模式。0 =静音,1 =进度条,2 =每个时期一行。请注意,进度条在登录到文件时不是特别有用,因此,如果不以交互方式运行(例如,在生产环境中),建议使用verbose = 2。
callbackskeras.callbacks.Callback实例 列表。训练期间要应用的回调列表。请参阅tf.keras.callbacks。
validation_split在0到1之间浮动。将训练数据的分数用作验证数据。模型将分开训练数据的这一部分,不对其进行训练,并且将在每个时期结束时评估此数据的损失和任何模型度量。在改组之前,从x和中y提供的最后一个样本中选择验证数据。当x是数据集,生成器或keras.utils.Sequence实例时, 不支持此参数。
validation_data在每个时期结束时用于评估损失的数据和任何模型指标。该模型将不会根据此数据进行训练。因此,请注意以下事实:使用 或不受正则化层(如噪声和压降)影响的数据验证损失。 将覆盖。 可能: validation_splitvalidation_datavalidation_datavalidation_splitvalidation_data 1.(x_val, y_val)Numpy数组或张量的元组. 2(x_val, y_val, val_sample_weights)Numpy数组的元组 .数据集对于前两种情况,batch_size必须提供。对于最后一种情况,validation_steps可以提供。请注意,validation_data它并不支持xdict,generator或中支持的所有数据类型keras.utils.Sequence。
shuffle布尔值(是否在每个纪元之前改组训练数据)或str(用于“批处理”)。当x是生成器时,将忽略此参数。“批处理”是处理HDF5数据限制的特殊选项;它以批量大小的块洗牌。当没有任何效果是不是。 steps_per_epochNone
class_weight可选的字典映射类索引(整数)到权重(浮动)值,用于加权损失函数(仅在训练期间)。这可能有助于告诉模型“更多关注”来自代表性不足的类的样本。
sample_weight训练样本的可选Numpy权重数组,用于加权损失函数(仅在训练过程中)。您可以传递长度与输入样本相同的平坦(1D)Numpy数组(权重和样本之间的1:1映射),或者对于时间数据,可以传递带有shape的2D数组 以应用每个样品的每个时间步均具有不同的权重。如果是数据集,生成器或 实例,而将sample_weights作为的第三个元素,则不支持此参数。 (samples, sequence_length)xkeras.utils.Sequencex
initial_epoch整数。开始训练的时期(用于恢复以前的训练运行)。
steps_per_epoch整数或None。声明一个纪元完成并开始下一个纪元之前的总步数(一批样品)。使用输入张量(例如TensorFlow数据张量)进行训练时,默认None值等于数据集中的样本数除以批次大小;如果无法确定,则默认为1。如果x是 tf.data数据集,并且’steps_per_epoch’为None,则该纪元将运行直到输入数据集用尽。传递无限重复的数据集时,必须指定 参数。数组输入不支持此参数。 steps_per_epoch
validation_steps仅在提供时才相关,并且是数据集。在每个时期结束时执行验证时,在停止之前要绘制的步骤总数(样本批次)。如果“ validation_steps”为“无”,则验证将一直进行到数据集用尽。如果是无限重复的数据集,它将陷入无限循环。如果指定了“ validation_steps”,并且仅消耗了一部分数据集,则评估将在每个时期从数据集的开头开始。这样可以确保每次都使用相同的验证样本。 validation_datatf.datavalidation_data
validation_batch_size整数或None。每个验证批次的样品数量。如果未指定,则默认为。不要指定数据是数据集,生成器还是实例的形式(因为它们生成批处理)。 batch_sizevalidation_batch_sizekeras.utils.Sequence
validation_freq仅在提供验证数据时才相关。整数或实例(例如列表,元组等)。如果为整数,则指定在执行新的验证运行之前要运行多少个训练时期,例如,每2个时期运行一次验证。如果是容器,则指定要运行验证的时期,例如,在第一个,第二个和第十个时期的末尾运行验证。 collections_abc.Containervalidation_freq=2validation_freq=[1, 2, 10]
max_queue_size整数。keras.utils.Sequence 仅用于生成器或输入。生成器队列的最大大小。如果未指定,则默认为10。 max_queue_size
workers整数。keras.utils.Sequence仅用于生成器或输入。使用基于进程的线程时,要启动的最大进程数。如果未指定,workers 则默认为1。如果为0,将在主线程上执行生成器。
use_multiprocessing布尔值。keras.utils.Sequence仅用于生成器或 输入。如果为True,请使用基于进程的线程。如果未指定,则默认为 。请注意,由于此实现依赖于多处理,因此不应将不可拾取的参数传递给生成器,因为它们无法轻易传递给子进程。 use_multiprocessingFalse

类似于迭代器的输入的拆包行为:一种常见的模式是将tf.data.Dataset,generator或tf.keras.utils.Sequence传递给fitx参数,这实际上不仅会产生特征(x),而且会产生可选结果目标(y)和样本权重。Keras要求此类类似迭代器的输出必须明确。迭代器应返回长度为1、2或3的元组,其中可选的第二和第三元素将分别用于y和sample_weight。提供的任何其他类型将被包裹在一个元组的长度中,从而将所有内容有效地视为“ x”。发出命令时,它们仍应遵循顶级元组结构。例如({“x0”: x0, “x1”: x1}, y)。Keras不会尝试从单个字典的键中分离特征,目标和权重。值得注意的不受支持的数据类型是namedtuple。原因是它的行为类似于有序数据类型(元组)和映射数据类型(dict)。因此,给定形式的namedtuple: namedtuple(“example_tuple”, [“y”, “x”]) 在解释值时是否反转元素的顺序是不明确的。更糟糕的是以下形式的元组: namedtuple(“other_tuple”, [“x”, “y”, “z”]) 尚不清楚该元组是否打算解包为x,y和sample_weight或作为单个元素传递给x。结果,如果数据处理代码遇到一个命名元组,它将仅引发ValueError。(以及纠正该问题的说明。)

函数返回:
历史记录对象。它的History.history属性记录了连续时期的训练损失值和度量值,以及验证损失值和验证度量值(如果适用)。

fit_generator函数

fit_generator(
    generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None,
    validation_data=None, validation_steps=None, validation_freq=1,
    class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False,
    shuffle=True, initial_epoch=0
)

使模型适合Python生成器逐批生成的数据。

在使用tensorflow2.0.0版本的时候的,去运行了tensorflow2.2.0版本的yoloV4代码,导致报错出现:
在这里插入图片描述
在这里插入图片描述

TypeError: int() argument must be a string, a bytes-like object or a number, not ‘tuple’

原来是因为在2.0.0版本的时候Model.fit不支持生成器创建的数据集,因此会出现错误。
把Model.fit函数换成Model.fit_generator函数,则成功解决这个问题。

如果帮助到您,点赞论评走一波!!!!!
谢谢各位大佬!!!!


本文链接: http://www.dtmao.cc/news_show_400377.shtml

附件下载

相关教程

    暂无相关的数据...

共有条评论 网友评论

验证码: 看不清楚?