
MindSpore / mindspore


[CT][MS][CI] CombineMomentumWeight and FusedWeightScaleApplyMomentum error

DONE
Bug-Report
Created on 2021-05-08 17:19
name: Bug Report
about: Use this template for reporting a bug
labels: kind/bug

Environment

  • Hardware Environment(Ascend/GPU/CPU):

Uncomment only one /device <> line, hit enter to put that in a new line, and remove leading whitespaces from that line:

/device gpu

  • Software Environment:
    -- MindSpore version (source or binary):
    -- Python version (e.g., Python 3.7.5):
    -- OS platform and distribution (e.g., Linux Ubuntu 16.04):
    -- GCC/Compiler version (if compiled from source):

Related testcase

test_ir_fusion_combine_momentum_weight
test_ir_fusion_fused_scale_momentum_decay

Steps to reproduce the issue

Describe the current behavior

import numpy as np
from mindspore import context, Tensor
from mindspore.nn import Momentum, SoftmaxCrossEntropyWithLogits
from mindspore.train import amp
from mindspore.train.loss_scale_manager import FixedLossScaleManager

# Net, Net2, clear_files and find_files are helpers defined in the test module (not shown in this report).

def test_ir_fusion_fused_scale_momentum_decay():
    clear_files()
    context.set_context(save_graphs=True)  # dump IR graphs so the fusion result can be checked

    epoch_size = 1
    batch_size = 1
    num_classes = 3

    input_np = np.random.uniform(0.0, 1.0,
                                 size=[batch_size, 3, 2, 2]).astype(np.float16)
    label_np = np.ones([batch_size, num_classes]).astype(np.float32)
    net = Net(3, num_classes)
    loss = SoftmaxCrossEntropyWithLogits(sparse=False)
    opt = Momentum(learning_rate=0.01, momentum=0.9,
                   params=filter(lambda x: x.requires_grad, net.get_parameters()),
                   weight_decay=1.5, loss_scale=1.5)
    lsm = FixedLossScaleManager(loss_scale=1.5, drop_overflow_update=False)
    net = amp.build_train_network(net, opt, loss,
                                  level="O3", loss_scale_manager=lsm)
    net.set_train()
    for _ in range(epoch_size):
        net(Tensor(input_np), Tensor(label_np))

    # Search the saved momentum_scale_fusion IR dumps for the fused operator.
    result = find_files('hwopt*momentum_scale_fusion*ir',
                        'FusedWeightScaleApplyMomentum')
    assert result == '2'

Failure:
E       AssertionError: assert '0' == '2'
E         - 0
E         + 2

def test_ir_fusion_combine_momentum_weight():
    clear_files()
    context.set_context(save_graphs=True)  # dump IR graphs so the fusion result can be checked

    epoch_size = 1
    batch_size = 1
    num_classes = 3

    input_np = np.random.uniform(0.0, 1.0,
                                 size=[batch_size, 3, 2, 2]).astype(np.float16)
    label_np = np.ones([batch_size, num_classes]).astype(np.float32)
    net = Net2(3, num_classes)
    loss = SoftmaxCrossEntropyWithLogits(sparse=False)
    # conv parameters use weight_decay=0.3, the others use lr=0.04;
    # order_params keeps the original parameter order.
    conv_params = list(filter(lambda x: 'conv' in x.name,
                              net.trainable_params()))
    no_conv_params = list(filter(lambda x: 'conv' not in x.name,
                                 net.trainable_params()))
    group_params = [{'params': conv_params, 'weight_decay': 0.3},
                    {'params': no_conv_params, 'lr': 0.04},
                    {'order_params': net.trainable_params()}]
    opt = Momentum(group_params, learning_rate=0.03, momentum=0.9,
                   loss_scale=1.3, weight_decay=0.7)
    lsm = FixedLossScaleManager(loss_scale=1.3, drop_overflow_update=False)
    net = amp.build_train_network(net, opt, loss, level="O3",
                                  loss_scale_manager=lsm)
    net.set_train()
    for _ in range(epoch_size):
        net(Tensor(input_np), Tensor(label_np))

    # Search the saved combine_momentum IR dumps for the fused operator.
    result = find_files('hwopt*combine_momentum*ir',
                        'CombineMomentumWeight')
    assert result == '2'

Failure:
E       AssertionError: assert '0' == '2'
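Both tests rely on clear_files and find_files, which the report does not show. Below is a hypothetical stand-in for find_files, assuming it counts the lines that contain the keyword across all IR dump files matching the glob pattern and returns that count as a string (which would explain why the assertions compare against '2'). It is an illustration only, not the helper used in the MindSpore test suite.

```python
import glob


def find_files(file_pattern: str, keyword: str) -> str:
    """Hypothetical helper: count lines containing `keyword` in files matching `file_pattern`.

    Returns the count as a string, mirroring the `result == '2'` comparisons above.
    """
    count = 0
    for path in glob.glob(file_pattern):
        with open(path, encoding="utf-8", errors="ignore") as f:
            count += sum(1 for line in f if keyword in line)
    return str(count)
```

Under this reading, a result of '0' means no matching IR dump contained the fused operator, i.e. the fusion pass never fired.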

Describe the expected behavior

Both test cases pass.

Related log / screenshot

Special notes for this issue

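Comments (2)

吴天瑜 created the Bug-Report
吴天瑜 set the associated repository to MindSpore/mindspore
吴天瑜 set the assignee to chenweifeng
吴天瑜 set the milestone to B-VM
吴天瑜 set the priority to Minor
吴天瑜 added the kind/bug label
吴天瑜 added the device/gpu (deleted) label
chenweifeng added chenweifeng as a collaborator
chenweifeng changed the assignee from chenweifeng to zhangqinghua
zhangqinghua changed the assignee from zhangqinghua to huangbingjian
huangbingjian set the planned start date to 2021-05-10
huangbingjian set the planned due date to 2021-05-31
huangbingjian changed the task status from TODO to WIP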

Appearance & Root Cause
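Problem: The expected fused operators do not appear in the IR graphs.
Cause: Because of a change in the ME front end, the IR graph that reaches the back end no longer matches the pattern the apply_momentum_weight_scale_fusion pass was written for: Cast(input) has become Depend(Cast(input)), so the fusion is not applied.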

Fix Solution
Solution: Adapt the apply_momentum_weight_scale_fusion pass so that it also handles the Depend(Cast(input)) case.
Related PR: !16210: update apply_momentum_weight_scale_fusion pass
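The root cause is a pattern-matching mismatch: the pass looked for Cast(input) but now receives Depend(Cast(input)). The following is a minimal Python sketch of the idea behind the fix, using a hypothetical Node class and hypothetical helpers skip_depend and match_cast. It is not MindSpore's actual C++ pass code or its pattern-pass API, and it assumes that Depend forwards the value of its first input.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Node:
    """Hypothetical IR node: an operator name and its input nodes (not MindSpore's AnfNode)."""
    op: str
    inputs: List["Node"] = field(default_factory=list)


def skip_depend(node: Node) -> Node:
    """Follow Depend wrappers to the value they forward (assumed to be input 0)."""
    while node.op == "Depend" and node.inputs:
        node = node.inputs[0]
    return node


def match_cast(node: Node, look_through_depend: bool) -> Optional[Node]:
    """Return the Cast node feeding the fusion pattern, or None if the pattern does not match."""
    if look_through_depend:
        node = skip_depend(node)
    return node if node.op == "Cast" else None


if __name__ == "__main__":
    x = Node("Parameter")
    cast = Node("Cast", [x])                         # shape the pass was written for
    wrapped = Node("Depend", [cast, Node("Attach")])  # shape after the front-end change ("Attach" is a placeholder)
    print(match_cast(cast, False) is cast)       # original pattern, original graph: matches
    print(match_cast(wrapped, False) is None)    # original pattern, new graph: no match, so no fusion
    print(match_cast(wrapped, True) is cast)     # adapted pattern, new graph: matches again
```

Looking through Depend, rather than removing it from the graph, keeps whatever ordering constraint the Depend expresses while still letting the fusion see the Cast it needs.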

huangbingjian changed the task status from WIP to VALIDATION
huangbingjian changed the assignee from huangbingjian to 吴天瑜
huangbingjian changed the milestone from B-VM to B-ComponentTest

pytest -s
test_ir_fusion_combine_momentum_weight
test_ir_fusion_fused_scale_momentum_decay

Result: pass

吴天瑜 changed the task status from VALIDATION to DONE
吴天瑜 removed the device/gpu (deleted) label
