3 Star 0 Fork 0

mirrors_lepy / ACORN

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
pruning_functions.py 1.25 KB
一键复制 编辑 原始数据 按行查看 历史
David Lindell 提交于 2021-08-09 12:44 . add files
import torch
import utils
def no_pruning(model, dataset, pruning_every=100):
    """No-op pruning function: accept the arguments and do nothing.

    Args:
        model: Unused.
        dataset: Unused.
        pruning_every: Unused.

    Returns:
        None.
    """
    return None
def pruning_occupancy(model, dataset, threshold=-10):
    """Freeze active octants whose predicted occupancy is everywhere low.

    Fetches evaluation samples from ``dataset``, evaluates ``model`` on them
    via ``utils.process_batch_in_chunks``, and sets ``octant.frozen = True``
    on every active octant whose largest predicted value is below
    ``threshold`` and whose ``err`` attribute is below 1e-3. Active octants
    that received no evaluation samples are left unfrozen. Ends by calling
    ``dataset.synchronize()``.

    Args:
        model: Model handed to ``utils.process_batch_in_chunks``.
        dataset: Object providing ``get_eval_samples``, ``octtree`` and
            ``synchronize``.
        threshold: An octant is only frozen when its maximum prediction is
            strictly below this value.

    Returns:
        None. The octants and the dataset are modified in place.
    """
    model_input = dataset.get_eval_samples(1)

    print("Pruning: loading data to cuda...")
    # Tensors get a new leading axis and are moved to the GPU; any other
    # value is forwarded unchanged.
    model_input = {
        key: value[None, ...].cuda() if isinstance(value, torch.Tensor) else value
        for key, value in model_input.items()
    }

    print("Pruning: evaluating occupancy...")
    pred_occupancy = utils.process_batch_in_chunks(model_input, model)['model_out']['output']
    # Reduce over the second-to-last axis, then drop singleton axes so the
    # predictions line up with the squeezed octant indices below.
    pred_occupancy = torch.max(pred_occupancy, dim=-2).values.squeeze()
    pred_occupancy_idx = model_input['coord_octant_idx'].squeeze()

    # NOTE: the log text says "mean" but the statistic used is the maximum;
    # the message is kept unchanged so existing log output stays the same.
    print("Pruning: computing mean and freezing empty octants")
    active_octants = dataset.octtree.get_active_octants()

    frozen_octants = 0
    for idx, octant in enumerate(active_octants):
        octant_predictions = pred_occupancy[pred_occupancy_idx == idx]
        if octant_predictions.numel() == 0:
            # torch.max raises on an empty tensor. With no samples there is
            # no evidence the octant is empty, so leave it active.
            continue

        max_prediction = torch.max(octant_predictions)
        # Prune if model is confident that everything is empty
        if max_prediction < threshold and octant.err < 1e-3:
            octant.frozen = True
            frozen_octants += 1

    print(f"Pruning: Froze {frozen_octants} octants.")
    dataset.synchronize()
1
https://gitee.com/mirrors_lepy/ACORN.git
git@gitee.com:mirrors_lepy/ACORN.git
mirrors_lepy
ACORN
ACORN
main

搜索帮助