MindSpore报错 For 'CellList'

发布时间:2023-03-29 15:30

1 报错描述
1.1 系统环境
Hardware Environment(Ascend/GPU/CPU): Ascend
Software Environment:
– MindSpore version (source or binary): 1.8.0
– Python version (e.g., Python 3.7.5): 3.7.6
– OS platform and distribution (e.g., Linux Ubuntu 16.04): Ubuntu 4.15.0-74-generic
– GCC/Compiler version (if compiled from source):

1.2 基本信息
1.2.1 脚本
训练脚本是通过构建CellList的单算子网络,实现cell列表容器。脚本如下:

01 class ListNoneExample(nn.Cell):
02 def __init__(self):
03 super(ListNoneExample, self).__init__()
04 self.lst = nn.CellList([nn.ReLU(), None, nn.ReLU()])
05
06 def construct(self, x):
07 output = []
08 for op in self.lst:
09 output.append(op(x))
10 return output
11
12 input = Tensor(np.random.normal(0, 2, (2, 1)).astype(np.float32))
13 example = ListNoneExample()
14 output = example(input)
15 print("Output:", output)
1.2.2 报错
这里报错信息如下:

Traceback (most recent call last):
File "C:/Users/l30026544/PycharmProjects/q2_map/new/I3OGVW.py", line 31, in

example = ListNoneExample()

File "C:/Users/l30026544/PycharmProjects/q2_map/new/I3OGVW.py", line 19, in init

self.lst = nn.CellList([nn.ReLU(), None, nn.ReLU()])

File "C:\Users\l30026544\PycharmProjects\q2_map\lib\site-packages\mindspore\nn\layer\container.py", line 310, in init

self.extend(args[0])

File "C:\Users\l30026544\PycharmProjects\q2_map\lib\site-packages\mindspore\nn\layer\container.py", line 405, in extend

if _valid_cell(cell, cls_name):

File "C:\Users\l30026544\PycharmProjects\q2_map\lib\site-packages\mindspore\nn\layer\container.py", line 39, in _valid_cell

raise TypeError(f'{msg_prefix} each cell should be subclass of Cell, but got {type(cell).__name__}.')

TypeError: For 'CellList', each cell should be subclass of Cell, but got NoneType.

原因分析

我们看报错信息,在TypeError中,写到For ‘CellList’, each cell should be subclass of Cell, but got NoneType.
,意思是对于CellList这个算子, 传入的每一个cell都因该是nn.Cell的子类, 但是得到了None类型。检查网络中初始化CellList的行为第4行, 发现传入了一个None, 因此报错。为了解决这个问题, 只需把这里的None换成一个继承于基类Cell类的对象, 就能实现相同的功能。

2 解决方法
基于上面已知的原因,很容易做出如下修改:

01 class NoneCell(nn.Cell):
02 def __init__(self):
03 super(NoneCell, self).__init__()
04
05 def construct(self, x):
06 return x
07
08 class ListNoneExample(nn.Cell):
09 def __init__(self):
10 super(ListNoneExample, self).__init__()
11 self.lst = nn.CellList([nn.ReLU(), NoneCell(), nn.ReLU()])
12
13 def construct(self, x):
14 output = []
15 for op in self.lst:
16 output.append(op(x))
17 return output
18
19 input = Tensor(np.random.normal(0, 2, (2, 1)).astype(np.float32))
20 example = ListNoneExample()
21 output = example(input)
22 print("Output:", output)
此时执行成功,输出如下:

Output: (Tensor(shape=[2, 1], dtype=Float32, value=
[[1.09826946e+000],
[0.00000000e+000]]), Tensor(shape=[2, 1], dtype=Float32, value=
[[1.09826946e+000],
[-2.74355006e+000]]), Tensor(shape=[2, 1], dtype=Float32, value=
[[1.09826946e+000],
[0.00000000e+000]]))
3 总结
定位报错问题的步骤:

1、找到报错的用户代码行:self.lst = nn.CellList([nn.ReLU(), None, nn.ReLU()]);

2、 根据日志报错信息中的关键字,缩小分析问题的范围each cell should be subclass of Cell, but got NoneType ;

3、需要重点关注变量定义、初始化的正确性。

4 参考文档
4.1 CellList算子API接口

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号