从数据的角度理解TensorFlow鸢尾花分类程序4
接上节,本文继续分析代码
# Fetch the data
(train_x, train_y), (test_x, test_y) = iris_data.load_data()
下载、导入和解析数据集从代码可以看出,load_data()在iris_data.py模块中,所以,打开iris_data.py文件
load_data()鸢尾花程序需要下列两个 .csv 文件中的数据:
http://download.tensorflow.org/data/iris_training.csv:其中包含训练集。
http://download.tensorflow.org/data/iris_test.csv:其中包含测试集。
训练集包含我们用于训练模型的样本;测试集包含我们用于评估训练后模型的效果的样本。
训练集和测试集起初是同一个数据集。然后,有人对样本进行拆分,大部分样本进入训练集,剩余部分进入测试集。向训练集添加样本通常会构建一个更好的模型;但是,向测试集添加更多样本则使我们能够更好地评估模型的效果。无论如何拆分,测试集中的样本都必须与训练集中的样本分隔开来。否则,您无法准确地确定模型的效果。
加入一个print(maybe_download()),可以看到
train_path的值是'C:\\Users\\tf\\.keras\\datasets\\iris_training.csv', test_path的值是'C:\\Users\\tf\\.keras\\datasets\\iris_test.csv',说明数据文件已经下载到~/.keras/datasets/文件夹了
train_path和train_path同上,加入一个print(load_data())语句,可以看到
(train_x, train_y), (test_x, test_y)大家若跟着本文做,则可以看到,文件iris_training.csv中的数据,被解析成:
训练数据集的特征值,储存在Dataframe类型的变量train_x里面;
训练数据集的特征标签,储存在Series类型的变量train_y里面;
机器学习中的记号约定,通常用x表示特征值,y表示特征标签。
同样,iris_test.csv中的数据,被解析成:
测试数据集的特征值,储存在Dataframe类型的变量test_x里面;
测试数据集的特征标签,储存在Series类型的变量test_y里面;
到此:
# Fetch the data
(train_x, train_y), (test_x, test_y) = iris_data.load_data()
的功能分析完毕,总结:iris_data.load_data()作用是:从指定URL路径下载两个文件iris_training.csv,iris_test.csv到本地~/.keras/datasets/文件夹中,然后把文件中的数据解析成训练数据集和测试数据集的特征值和标签,分别存在变量(train_x, train_y), (test_x, test_y)中,返回给主程序使用。
PS. 个人认为
test_path = tf.keras.utils.get_file(TEST_URL.split('/')[-1], TEST_URL) 中的TEST_URL.split('/')[-1],有点儿烧脑,其作用是:把字符串TEST_URL按照'/'切分成字符串列表,然后取最后一个字符串,这本意就是从路径中取得文件名。
这种功能,我个人喜欢用清晰明了的os.path.basename(path)
我个人的编程倾向是 ,标准库里面实现的功能,尽量用标准库中的函数来实现,另起炉灶,会让阅读代码的人烧脑:)
熟悉并熟练运用Python标准库里面的函数,是一个专业Python程序员的义务和责任。