2.3K Star 8K Fork 4.2K

GVPMindSpore / mindspore

 / 详情

使用华为云昇腾芯片报warming:cannot find valid TBE kernel info, try to get aicpu kernel info

DONE
Bug-Report
创建于  
2021-10-21 23:59
[WARNING] DEVICE(93202,fffebaffd1e0,python):2021-10-14-12:35:29.593.936 [mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc:493] SelectKernelInfo] The node [kernel_graph_11:[CNode]31{[0]: ValueNode<PrimitivePy> Cast, [1]: [Parameter]32}] cannot find valid TBE kernel info, try to get aicpu kernel info
[WARNING] DEVICE(93202,fffebaffd1e0,python):2021-10-14-12:35:52.044.966 [mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc:493] SelectKernelInfo] The node [kernel_graph_15:[CNode]40{[0]: ValueNode<PrimitivePy> Cast, [1]: [Parameter]41}] cannot find valid TBE kernel info, try to get aicpu kernel info

模型报这种类似warming,连续报了非常多个之后才开始跑模型,然后跑得特别慢,而且跑到一半之后,模型停止。
输入图片说明
输入图片说明
time是一个batch的time 跑到第4个batch,模型不动了

测试之后发现,只使用最简单的model进行前向传播求模型输出,不进行反向传播和求损失,模型一样跑得特别慢,怀疑是model前向传播过程出现问题。在cpu上,前向传播速度非常快。以下是测试前向传播的代码。

train_iter = build_dataloader(train_data, config.batch_size, False)
net = Model(config)
print('-------train--------')
    for i, data in enumerate(train_iter.create_dict_iterator()):
        print(i)
        start_time = time()
        print(net(data['id'],data['ngram']))
        end_time = time()
        print('time:' + str(end_time - start_time))

model定义:

class Model(nn.Cell):
    def __init__(self, config):
        super(Model, self).__init__()
        self.config = config
        self.c_seed = 1.0
        self.manifold = PoincareBall()
        self.concat = ops.Concat(1)
        if config.embedding_pretrained is not None:
            pass
        else:
            # emb = Tensor(xavier_normal(Tensor(np.random.random((config.n_vocab, config.embed)))), dtype=ms.float32)
            # test data
            emb = Tensor(xavier_normal(Tensor(np.random.random((500, config.embed)))), dtype=ms.float32)
            emb[0] = emb[0].fill(0)
            self.embedding = Parameter(emb, requires_grad=True, )
            self.embedding.manifold = self.manifold
            self.embedding.c = self.c_seed

        # emb_wordngram = Tensor(xavier_normal(Tensor(np.random.random((config.bucket, config.embed)))), dtype=ms.float32)
        # test data
        emb_wordngram = Tensor(xavier_normal(Tensor(np.random.random((500, config.embed)))), dtype=ms.float32)
        emb_wordngram[0] = emb_wordngram[0].fill(0)
        self.embedding_wordngram = Parameter(emb_wordngram, requires_grad=True, )
        self.embedding.manifold = self.manifold
        self.embedding.c = self.c_seed
        # 这里的drop传入参数意义与pytorch相反,在原文中传入default为0.0,也就是说在此处需要default为1.0
        self.dropout = nn.Dropout(config.dropout)
        self.hyperLinear = MobiusLinear(self.manifold, config.embed,
                                        config.num_classes, c=self.c_seed)

    def construct(self, x_1, x_2):
        out_word = self.embedding[x_1]
        out_wordngram = self.embedding_wordngram[x_2]
        out = self.concat((out_word, out_wordngram))
        out = self.manifold.einstein_midpoint(out, c=self.c_seed)
        out = self.hyperLinear(out)
        out = self.manifold.logmap0(out, self.c_seed)

        return out

评论 (11)

WangWeihang 创建了Bug-Report

Please assign maintainer to check this issue.
请为这个issue分配处理人, @fangwenyi @chengxiaoli

Please add labels (comp or sig),also you can visit "https://gitee.com/mindspore/community/blob/master/sigs/dx/docs/labels.md" to find more.
为了让问题更快得到响应,请您为该issue打上 组件(comp)或兴趣组(sig) 标签,打上标签的问题可以直接推送给责任人进行处理。更多的标签可以查看
https://gitee.com/mindspore/community/blob/master/sigs/dx/docs/labels.md
以组件问题为例,如果你发现问题是data组件造成的,你可以这样评论:
//comp/data
当然你也可以向data SIG组求助,可以这样写:
//comp/data
//sig/data
如果是一个简单的问题,你可以留给刚进入社区的小伙伴来回答,这时候你可以这样写:
//good-first-issue
恭喜你,你已经学会了使用命令来打标签,接下来就在下面的评论里打上标签吧!

i-robot 添加了
 
kind/bug
标签

hello, @WangWeihang @WangWeihang , we suggest you add some labels like:
你好, @WangWeihang @WangWeihang , 建议您为这个issue打上标签:
//comp/device

WangWeihang 修改了描述
WangWeihang 修改了描述

问题已经收到,会尽快分析答复,请耐心等待下

chengxiaoli 任务状态TODO 修改为ACCEPTED
chengxiaoli 负责人设置为chengxiaoli
chengxiaoli 优先级设置为主要
chengxiaoli 添加了
 
mindspore-assistant
标签

https://www.mindspore.cn/mindinsight/docs/zh-CN/r1.5/performance_profiling_ascend.html

您可以参考上述profiling工具,进行一下算子耗时分析吗?

导出结果中应该能看到每个算子的耗时,根据描述,有可能是某些算子因为没有对应TBE的支持,而选择了AICPU的算子,从而拖慢了整体计算速度。
看模型结构里,有一个self.manifold = PoincareBall(),不知道是什么结构,报不支持的是cast算子,有可能存在一些类似float64相关的cast操作,由于Ascend芯片的支持问题不能在加速核上运行,只能选择aicpu算子。

具体情况的话,最好参考profiling的结果进行分析,如果有进一步的信息,可以在issue里进行补充。

PoincareBall的代码,是一个用于计算的类。

class PoincareBall(object):

    def __init__(self, ):
        super(PoincareBall, self).__init__()
        self.name = 'PoincareBall'
        self.min_norm = 1e-15
        self.eps = {mindspore.float32: 4e-3, mindspore.float64: 1e-5}

    def sqdist(self, p1, p2, c):
        sqrt_c = c ** 0.5
        dist_c = artanh(
            np.norm(sqrt_c * self.mobius_add(-p1, p2, c, dim=-1), axis=-1)
        )
        dist = dist_c * 2 / sqrt_c
        return dist ** 2

    def _lambda_x(self, x, c):
        Pow = ops.Pow()
        Rsum = ops.ReduceSum(keep_dims=True)
        x_sqnorm = Rsum(Pow(x, 2), -1)
        return 2 / np.clip((1. - c * x_sqnorm), self.min_norm, np.inf)

    def egrad2rgrad(self, p, dp, c):
        lambda_p = self._lambda_x(p, c)
        Pow = ops.Pow()
        dp /= Pow(lambda_p, 2)
        return dp

    def proj(self, x, c):
        norm = np.clip(np.norm(x, axis=-1, keepdims=True), self.min_norm, np.inf)
        maxnorm = (1 - self.eps[x.dtype]) / (c ** 0.5)
        projected = x / norm * maxnorm
        cond = norm > maxnorm
        return np.where(cond, projected, x)

    def proj_tan(self, u):
        return u

    def proj_tan0(self, u):
        return u

    def expmap(self, u, p, c):
        sqrt_c = c ** 0.5
        u_norm = np.clip(np.norm(u, axis=-1, keepdims=True), self.min_norm, np.inf)
        second_term = (
                clamp_tanh(sqrt_c / 2 * self._lambda_x(p, c) * u_norm)
                * u
                / (sqrt_c * u_norm)
        )
        gamma_1 = self.mobius_add(p, second_term, c)
        return gamma_1

    def logmap(self, p1, p2, c):
        sub = self.mobius_add(-p1, p2, c)
        sub_norm = np.clip(np.norm(sub, axis=-1, keepdims=True), self.min_norm, np.inf)
        lam = self._lambda_x(p1, c)
        sqrt_c = c ** 0.5
        return 2 / sqrt_c / lam * artanh(sqrt_c * sub_norm) * sub / sub_norm

    def expmap0(self, u, c):
        sqrt_c = c ** 0.5
        u_norm = np.clip(np.norm(u, axis=-1, keepdims=True), self.min_norm, np.inf)
        gamma_1 = clamp_tanh(sqrt_c * u_norm) * u / (sqrt_c * u_norm)
        return gamma_1

    def logmap0(self, p, c):
        sqrt_c = c ** 0.5
        p_norm = np.clip(np.norm(p, axis=-1, keepdims=True), self.min_norm, np.inf)
        scale = 1. / sqrt_c * artanh(sqrt_c * p_norm) / p_norm
        return scale * p

    def mobius_add(self, x, y, c, dim=-1):
        Pow = ops.Pow()
        Rsum = ops.ReduceSum(keep_dims=True)
        x2 = Rsum(Pow(x, 2), dim)
        y2 = Rsum(Pow(y, 2), dim)
        xy = Rsum(x * y, dim)
        left = (1 + 2 * c * xy + c * y2) * x
        right = np.dot((1 - c * x2),y)
        num = left + right
        denom = 1 + 2 * c * xy + c ** 2 * x2 * y2
        return num / np.clip(denom, self.min_norm, np.inf)

    def mobius_matvec(self, m, x, c):
        sqrt_c = c ** 0.5
        x_norm = np.clip(np.norm(x, axis=-1, keepdims=True), self.min_norm, np.inf)
        mx = np.matmul(x, m.swapaxes(-1, -2))
        mx_norm = np.clip(np.norm(mx, axis=-1, keepdims=True), self.min_norm, np.inf)
        t1 = artanh(sqrt_c * x_norm)
        t2 = clamp_tanh(mx_norm / x_norm * t1)
        res_c = t2 * mx / (mx_norm * sqrt_c)

        t = (mx == 0).astype(mindspore.uint8).asnumpy()
        cond = numpy.prod(t, -1, keepdims=True)
        zeros = ops.Zeros()
        res_0 = zeros(1, res_c.dtype)
        res = np.where(Tensor(cond), res_0, res_c)
        return res

    def init_weights(self, w, irange=1e-5):
        shape = w.shape
        w = numpy.random.uniform(-irange, irange, shape)
        w = Tensor(w).set_dtype(mstype.float64)
        return w

    def _gyration(self, u, v, w, c, dim: int = -1):
        Pow = ops.Pow()
        Rsum = ops.ReduceSum(keep_dims=True)
        u2 = Rsum(Pow(u, 2), dim)
        v2 = Rsum(Pow(v, 2), dim)
        uv = Rsum(u * v, dim)
        uw = Rsum(u * w, dim)
        vw = Rsum(v * w, dim)
        c2 = c ** 2
        a = -c2 * uw * v2 + c * vw + 2 * c2 * uv * vw
        b = -c2 * vw * u2 - c * uw
        d = 1 + 2 * c * uv + c2 * u2 * v2
        return w + 2 * (a * u + b * v) / np.clip(d, self.min_norm, np.inf)

    def inner(self, x, c, u, v=None, keepdim=False):
        if v is None:
            v = u
        lambda_x = self._lambda_x(x, c)
        Rsum = ops.ReduceSum(keep_dims=keepdim)
        return Rsum(lambda_x ** 2 * (u * v), -1)

    def ptransp(self, x, y, u, c):
        lambda_x = self._lambda_x(x, c)
        lambda_y = self._lambda_x(y, c)
        return self._gyration(y, -x, u, c) * lambda_x / lambda_y

    def ptransp_(self, x, y, u, c):
        lambda_x = self._lambda_x(x, c)
        lambda_y = self._lambda_x(y, c)
        return self._gyration(y, -x, u, c) * lambda_x / lambda_y

    def to_hyperboloid(self, x, c):
        K = 1.0 / c
        sqrtK = K ** 0.5
        sqnorm = np.norm(x, axis=-1, keepdims=True) ** 2
        cat = ops.Concat(-1)
        sqrtK * cat(K + sqnorm, sqrtK * x) / (K - sqnorm)
        return sqrtK * cat((K + sqnorm, sqrtK * x)) / (K - sqnorm)

    def klein_constraint(self, x):
        last_dim_val = x.shape[-1]
        norm = np.norm(x, axis=-1).reshape(-1, 1)
        maxnorm = (1 - self.eps[x.dtype])
        cond = norm > maxnorm
        x_reshape = x.reshape(-1, last_dim_val)
        projected = x_reshape / (norm + self.min_norm) * maxnorm
        x_reshape = np.where(cond, projected, x_reshape)
        x = x_reshape.reshape(x.shape)
        return x

    def to_klein(self, x, c):
        Rsum = ops.ReduceSum(keep_dims=True)
        x_2 = Rsum(x * x, -1)
        x_klein = 2 * x / (1.0 + x_2)
        x_klein = self.klein_constraint(x_klein)
        return x_klein

    def klein_to_poincare(self, x, c):
        sqrt = ops.Sqrt()
        Rsum = ops.ReduceSum(keep_dims=True)
        x_poincare = x / (1.0 + sqrt(1.0 - Rsum(x * x, -1)))
        # print(x_poincare)
        x_poincare = self.proj(x_poincare, c)
        # print(x_poincare)
        return x_poincare

    def lorentz_factors(self, x):
        x_norm = np.norm(x, axis=-1)
        return 1.0 / (1.0 - x_norm ** 2 + self.min_norm)

    def einstein_midpoint(self, x, c):
        expand_dims = ops.ExpandDims()
        Rsum = ops.ReduceSum(keep_dims=True)
        x = self.to_klein(x, c)
        x_lorentz = self.lorentz_factors(x)
        x_norm = np.norm(x, axis=-1)
        # deal with pad value
        x_lorentz = (1.0 - (x_norm == 0)) * x_lorentz
        x_lorentz_sum = Rsum(x_lorentz, -1)
        x_lorentz_expand = expand_dims(x_lorentz, -1)
        x_midpoint = Rsum(x_lorentz_expand * x, 1).reshape(x.shape[0],-1) / x_lorentz_sum
        x_midpoint = self.klein_constraint(x_midpoint)
        x_p = self.klein_to_poincare(x_midpoint, c)
        return x_p

这个看代码不太直接,有profiling的结果吗?

profiling是不是只能使用graph模式,我们的模型在graph模式下会报错 还在调

profiling是不是只能使用graph模式,我们的模型在graph模式下会报错 还在调

@WangWeihang pynative模式支持还不完善,但是也可以运行一下看看。如果是1.5的最新版本mindspore的话,pynative模式下进行profiling,运行目录下会有一个类似pynative_forward_profiling_data.csv的文件,里面有各个算子的时间。
输入图片说明

fangwenyi 任务状态ACCEPTED 修改为WIP
fangwenyi 负责人chengxiaoli 修改为chenhaozhe
fangwenyi 添加协作者chengxiaoli
fangwenyi 里程碑设置为B-SIG-ModelZoo
fangwenyi 添加了DFX/start-analysis(已删除)标签
chenhaozhe 添加了
 
stat/wait-response
标签
chenhaozhe 移除了
 
stat/wait-response
标签
oacjiewen 添加了
 
v1.5.1
标签

请使用最新版本验证,此ISSUE先关闭,如有需要请重新提单,或者自行修改ISSUE状态,谢谢

fangwenyi 任务状态WIP 修改为VALIDATION
fangwenyi 任务状态VALIDATION 修改为DONE
fangwenyi 移除了DFX/start-analysis(已删除)标签

登录 后才可以发表评论

状态
负责人
项目
里程碑
Pull Requests
关联的 Pull Requests 被合并后可能会关闭此 issue
分支
开始日期   -   截止日期
-
置顶选项
优先级
预计工期 (小时)
参与者(5)
7541106 cheng xiaoli 1706167756 5644189 c 34 1586403335
Python
1
https://gitee.com/mindspore/mindspore.git
git@gitee.com:mindspore/mindspore.git
mindspore
mindspore
mindspore

搜索帮助