Dive Into MindSpore – CSVDataset For Dataset LoadMindSpore精讲系列 – 数据集加载之CSVDataset本文开发环境Ubuntu 20.04Python 3.8MindSpore 1.7.0本文内容摘要先看API数据准备两种试错正确示例本文总结问题改进本文参考1. 先看API老传统,先看看官方文档:
参数解读:dataset_files – 数据集文件路径,可以单文件也可以是文件列表filed_delim – 字段分割符,默认为","column_defaults – 一个巨坑的参数,留待后面解读column_names – 字段名,用于后续数据字段的keynum_paraller_workers – 不再解释shuffle – 是否打乱数据,三种选择[False, Shuffle.GLOBAL, Shuffle.FILES]Shuffle.GLOBAL – 混洗文件和文件中的数据,默认Shuffle.FILES – 仅混洗文件num_shards – 不再解释shard_id – 不再解释cache – 不再解释2. 数据准备2.1 数据下载说明:数据下载地址:UCI Machine Learning Repository: Iris Data Set使用如下命令下载数据iris.data和iris.names到目标目录:mkdir iris && cd iris
wget -c https://archive.ics.uci.edu/m...
wget -c https://archive.ics.uci.edu/m...
备注:如果受系统限制,无法使用wget命令,可以考虑用浏览器下载,下载地址见说明。2.2 数据简介Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。更详细的介绍参见官方说明:5. Number of Instances: 150 (50 in each of three classes)
- Number of Attributes: 4 numeric, predictive attributes and the class
Attribute Information:
- sepal length in cm
- sepal width in cm
- petal length in cm
- petal width in cm
- class:
-- Iris Setosa
-- Iris Versicolour
-- Iris Virginica
- Missing Attribute Values: None
Summary Statistics:
Min Max Mean SD Class Correlation
sepal length: 4.3 7.9 5.84 0.83 0.7826
sepal width: 2.0 4.4 3.05 0.43 -0.4194
petal length: 1.0 6.9 3.76 1.76 0.9490 (high!)
petal width: 0.1 2.5 1.20 0.76 0.9565 (high!)
- Class Distribution: 33.3% for each of 3 classes.
2.3 数据分配这里对数据进行初步分配,分成训练集和测试集,分配比例为4:1。相关处理代码如下:from random import shuffle
def preprocess_iris_data(iris_data_file, train_file, test_file, header=True):
cls_0 = "Iris-setosa"
cls_1 = "Iris-versicolor"
cls_2 = "Iris-virginica"
cls_0_samples = []
cls_1_samples = []
cls_2_samples = []
with open(iris_data_file, "r", encoding="UTF8") as fp:
lines = fp.readlines()
for line in lines:
line = line.strip()
if not line:
continue
if cls_0 in line:
cls_0_samples.append(line)
continue
if cls_1 in line:
cls_1_samples.append(line)
continue
if cls_2 in line:
cls_2_samples.append(line)
shuffle(cls_0_samples)
shuffle(cls_1_samples)
shuffle(cls_2_samples)
print("number of class 0: {}".format(len(cls_0_samples)), flush=True)
print("number of class 1: {}".format(len(cls_1_samples)), flush=True)
print("number of class 2: {}".format(len(cls_2_samples)), flush=True)
train_samples = cls_0_samples[:40] + cls_1_samples[:40] + cls_2_samples[:40]
test_samples = cls_0_samples[40:] + cls_1_samples[40:] + cls_2_samples[40:]
header_content = "Sepal.Length,Sepal.Width,Petal.Length,Petal.Width,Classes"
with open(train_file, "w", encoding="UTF8") as fp:
if header:
fp.write("{}\n".format(header_content))
for sample in train_samples:
fp.write("{}\n".format(sample))
with open(test_file, "w", encoding="UTF8") as fp:
if header:
fp.write("{}\n".format(header_content))
for sample in test_samples:
fp.write("{}\n".format(sample))
def main():
iris_data_file = "{your_path}/iris/iris.data"
iris_train_file = "{your_path}/iris/iris_train.csv"
iris_test_file = "{your_path}/iris/iris_test.csv"
preprocess_iris_data(iris_data_file, iris_train_file, iris_test_file)
if name == "__main__":
main()
将以上代码保存到preprocess.py文件,使用如下命令运行:注意修改相关数据文件路径python3 preprocess.py
输出内容如下:number of class 0: 50
number of class 1: 50
number of class 2: 50
同时在目标目录生成iris_train.csv和iris_test.csv文件,目录内容如下所示:.
├── iris.data
├── iris.names
├── iris_test.csv
└── iris_train.csv
- 两种试错下面通过几种错误(带引号)用法,来初步认识一下CSVDataset。3.1 column_defaults是哪样首先,先来个简单加载,代码如下:为方便读者复现,这里将shuffle设置为False。from mindspore.dataset import CSVDataset
def dataset_load(data_files):
column_defaults = [float, float, float, float, str]
column_names = ["Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width", "Classes"]
dataset = CSVDataset(
dataset_files=data_files,
field_delim=",",
column_defaults=column_defaults,
column_names=column_names,
num_samples=None,
shuffle=False)
data_iter = dataset.create_dict_iterator()
item = None
for data in data_iter:
item = data
break
print("====== sample ======\n{}".format(item), flush=True)
def main():
iris_train_file = "{your_path}/iris/iris_train.csv"
dataset_load(data_files=iris_train_file)
if name == "__main__":
main()
将以上代码保存到load.py文件,运行命令:注意修改数据文件路径python3 load.py
纳尼,报错,来看看报错内容:Traceback (most recent call last):
File "/Users/kaierlong/Documents/Codes/OpenI/Dive_Into_MindSpore/code/chapter_01/csv_dataset.py", line 107, in
main()
File "/Users/kaierlong/Documents/Codes/OpenI/Dive_Into_MindSpore/code/chapter_01/csv_dataset.py", line 103, in main
dataset_load(data_files=iris_train_file)
File "/Users/kaierlong/Documents/Codes/OpenI/Dive_Into_MindSpore/code/chapter_01/csv_dataset.py", line 75, in dataset_load
dataset = CSVDataset(
File "/Users/kaierlong/Documents/PyEnv/env_ms_1.7.0/lib/python3.9/site-packages/mindspore/dataset/engine/validators.py", line 1634, in new_method
raise TypeError("column type in column_defaults is invalid.")
TypeError: column type in column_defaults is invalid.
看看引发报错的源码,mindspore/dataset/engine/validators.py 中1634行,相关代码如下: # check column_defaults
column_defaults = param_dict.get('column_defaults')
if column_defaults is not None:
if not isinstance(column_defaults, list):
raise TypeError("column_defaults should be type of list.")
for item in column_defaults:
if not isinstance(item, (str, int, float)):
raise TypeError("column type in column_defaults is invalid.")
3.1.1 报错分析更多关于column_defaults参数的分析请参考6.1节。还记得官方参数说明吗,不记得没关系,这里再列出来。column_defaults (list, 可选) - 指定每个数据列的数据类型,有效的类型包括float、int或string。默认值:None,不指定。如果未指定该参数,则所有列的数据类型将被视为string。很显然,官方参数说明是数据类型,但是到mindspore/dataset/engine/validators.py代码里面,却检测的是数据实例类型。明确了这点,将代码:column_defaults = [float, float, float, float, str]
修改为:这里的数值取自iris.names文件,详情参考该文件。column_defaults = [5.84, 3.05, 3.76, 1.20, "Classes"]
再次运行代码,再次报错:WARNING: Logging before InitGoogleLogging() is written to STDERR
[ERROR] MD(13306,0x70000269b000,Python):2022-06-14-16:51:59.681.109 [mindspore/ccsrc/minddata/dataset/util/task_manager.cc:217] InterruptMaster] Task is terminated with err msg(more detail in info level log):Unexpected error. Invalid csv, csv file: /Users/kaierlong/Downloads/iris/iris_train.csv parse failed at line 1, type does not match.
Line of code : 506
File : /Users/jenkins/agent-working-dir/workspace/Compile_CPU_X86_MacOS_PY39/mindspore/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc
Traceback (most recent call last):
File "/Users/kaierlong/Documents/Codes/OpenI/Dive_Into_MindSpore/code/chapter_01/csv_dataset.py", line 107, in
main()
File "/Users/kaierlong/Documents/Codes/OpenI/Dive_Into_MindSpore/code/chapter_01/csv_dataset.py", line 103, in main
dataset_load(data_files=iris_train_file)
File "/Users/kaierlong/Documents/Codes/OpenI/Dive_Into_MindSpore/code/chapter_01/csv_dataset.py", line 90, in dataset_load
for data in data_iter:
File "/Users/kaierlong/Documents/PyEnv/env_ms_1.7.0/lib/python3.9/site-packages/mindspore/dataset/engine/iterators.py", line 147, in next
data = self._get_next()
File "/Users/kaierlong/Documents/PyEnv/env_ms_1.7.0/lib/python3.9/site-packages/mindspore/dataset/engine/iterators.py", line 211, in _get_next
raise err
File "/Users/kaierlong/Documents/PyEnv/env_ms_1.7.0/lib/python3.9/site-packages/mindspore/dataset/engine/iterators.py", line 204, in _get_next
return {k: self._transform_tensor(t) for k, t in self._iterator.GetNextAsMap().items()}
RuntimeError: Unexpected error. Invalid csv, csv file: /Users/kaierlong/Downloads/iris/iris_train.csv parse failed at line 1, type does not match.
Line of code : 506
File : /Users/jenkins/agent-working-dir/workspace/Compile_CPU_X86_MacOS_PY39/mindspore/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc
好了,这个错误我们到3.2部分进行分析。3.2 header要不要在3.1中,我们根据对报错源码的分析,明确了column_defaults的正确用法,但是依然存在一个错误。3.2.1 报错分析根据报错信息,发现是mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc中506行的报错,相关源码如下:Status CsvOp::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
CsvParser csv_parser(worker_id, jagged_rows_connector_.get(), field_delim_, column_default_list_, file);
RETURN_IF_NOT_OK(csv_parser.InitCsvParser());
csv_parser.SetStartOffset(start_offset);
csv_parser.SetEndOffset(end_offset);
auto realpath = FileUtils::GetRealPath(file.c_str());
if (!realpath.has_value()) {
MS_LOG(ERROR) << "Invalid file path, " << file << " does not exist.";
RETURN_STATUS_UNEXPECTED("Invalid file path, " + file + " does not exist.");
}
std::ifstream ifs;
ifs.open(realpath.value(), std::ifstream::in);
if (!ifs.is_open()) {
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open " + file + ", the file is damaged or permission denied.");
}
if (column_name_list_.empty()) {
std::string tmp;
getline(ifs, tmp);
}
csv_parser.Reset();
try {
while (ifs.good()) {
// when ifstream reaches the end of file, the function get() return std::char_traits::eof()
// which is a 32-bit -1, it's not equal to the 8-bit -1 on Euler OS. So instead of char, we use
// int to receive its return value.
int chr = ifs.get();
int err = csv_parser.ProcessMessage(chr);
if (err != 0) {
// if error code is -2, the returned error is interrupted
if (err == -2) return Status(kMDInterrupted);
RETURN_STATUS_UNEXPECTED("Invalid file, failed to parse csv file: " + file + " at line " +
std::to_string(csv_parser.GetTotalRows() + 1) +
". Error message: " + csv_parser.GetErrorMessage());
}
}
} catch (std::invalid_argument &ia) {
std::string err_row = std::to_string(csv_parser.GetTotalRows() + 1);
RETURN_STATUS_UNEXPECTED("Invalid csv, csv file: " + file + " parse failed at line " + err_row +
", type does not match.");
} catch (std::out_of_range &oor) {
std::string err_row = std::to_string(csv_parser.GetTotalRows() + 1);
RETURN_STATUS_UNEXPECTED("Invalid csv, " + file + " parse failed at line " + err_row + " : value out of range.");
}
return Status::OK();
}
通过阅读上面的源码,发现源码中没有处理header行的代码,即默认所有行都是数据行。还记得2.3中数据分配部分的代码,我们写入了header信息,而CSVDataset并不提供处理header行的能力。现在根据报错分析定位,对2.3的数据分配代码进行修改,将代码preprocess_iris_data(iris_data_file, iris_train_file, iris_test_file)
修改为preprocess_iris_data(iris_data_file, iris_train_file, iris_test_file, header=False)
再次运行preprocess.py文件,生成新的数据。然后运行load.py文件(这里并不需要再改代码),输出内容如下:说明:为方便读者查看,这里对格式进行了人为处理,内容不变。这里已经能够正确读取数据,数据包含5个字段。数据字段名已经根据指定的column_names做了处理。====== sample ======
{'Sepal.Length': Tensor(shape=[], dtype=Float32, value= 5.5), 'Sepal.Width': Tensor(shape=[], dtype=Float32, value= 4.2), 'Petal.Length': Tensor(shape=[], dtype=Float32, value= 1.4), 'Petal.Width': Tensor(shape=[], dtype=Float32, value= 0.2),
'Classes': Tensor(shape=[], dtype=String, value= 'Iris-setosa')}
- 正确示例通过3中的两种试错,我们对CSVDataset有了初步认识,细心的读者可能会发现,3中依然有一个问题,那就是Classes字段没有进行数值化,下面我们就来介绍一种对其数值化的方法。源码如下:from mindspore.dataset import CSVDataset
from mindspore.dataset.text import Lookup, Vocab
def dataset_load(data_files):
column_defaults = [5.84, 3.05, 3.76, 1.20, "Classes"]
column_names = ["Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width", "Classes"]
dataset = CSVDataset(
dataset_files=data_files,
field_delim=",",
column_defaults=column_defaults,
column_names=column_names,
num_samples=None,
shuffle=False)
cls_to_id_dict = {"Iris-setosa": 0, "Iris-versicolor": 1, "Iris-virginica": 2}
vocab = Vocab.from_dict(word_dict=cls_to_id_dict)
lookup = Lookup(vocab)
dataset = dataset.map(input_columns="Classes", operations=lookup)
data_iter = dataset.create_dict_iterator()
item = None
for data in data_iter:
item = data
break
print("====== sample ======\n{}".format(item), flush=True)
def main():
iris_train_file = "{your_path}/iris/iris_train.csv"
dataset_load(data_files=iris_train_file)
if name == "__main__":
main()
将以上代码保存到load.py文件,运行命令:注意修改数据文件路径python3 load.py
输出内容如下:说明:数据包含5个字段。Classes字段已经根据cls_to_id_dict = {"Iris-setosa": 0, "Iris-versicolor": 1, "Iris-virginica": 2}进行了数值化转换。数值化转换用到了mindspore.dataset.text的有关方法,读者可以自行查阅,后续会出相关的解读文章。====== sample ======
{'Sepal.Length': Tensor(shape=[], dtype=Float32, value= 5.5), 'Sepal.Width': Tensor(shape=[], dtype=Float32, value= 4.2), 'Petal.Length': Tensor(shape=[], dtype=Float32, value= 1.4), 'Petal.Width': Tensor(shape=[], dtype=Float32, value= 0.2),
'Classes': Tensor(shape=[], dtype=Int32, value= 0)}
后续:这里还存在其他字段的数据归一化,就留待读者去尝试了。数值化转换部分,也可以通过在数据分配部分增加代码来提前转换,读者也可以进行尝试。5. 本文总结本文对MindSpore中的CSVDataset数据集接口进行了探索和示例展示。通过错误试探,发现目前CSVDataset的文档和功能还相对较弱,只能说是可用。6. 问题改进6.1 column_defaults文档错误英文文档column_defaults (list, optional) – List of default values for the CSV field (default=None). Each item in the list is either a valid type (float, int, or string). If this is not provided, treats all columns as string type.
中文文档column_defaults (list, 可选) - 指定每个数据列的数据类型,有效的类型包括float、int或string。默认值:None,不指定。如果未指定该参数,则所有列的数据类型将被视为string。
这里中文翻译有误。其实英文API就有一定的歧义性,前面说了是每个字段的默认值(CSV文件中存在字段为空的情况),后面又说如果为空,则按照string类型处理,让人分不清究竟是数据类型实例还是数据类型。注意:其实这里既有数据类型实例的意思,又有数据类型的意思。当指定了column_defaults参数,则字段的默认值为column_defaults中相应位置的值,字段的类型为column_defaults相应位置值的数据类型。例如:某CSV文件包含三个字段,指定column_defaults为[2.0, 1, “x”],则读取该文件时,三个字段的类型会被识别为float、int、str,如果某行中第二个字段为空,则就用默认值1填充。6.2 不支持文件含有header如题6.3 不支持读取指定字段如题,API层面不显式支持,不过可以通过后续的数据处理来支持。7. 本文参考mindspore.dataset.CSVDatasetmindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc本文为原创文章,版权归作者所有,未经授权不得转载!