
How Can I Calculate FLOPs And Params Without 0 Weights Neurons Affected?

My pruning code is shown below. After running it, I get a file named 'pruned_model.pth'.

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn

Solution 1:

One option is to exclude weights whose magnitude is below a certain threshold from the FLOPs computation. To do so, you would have to modify the FLOP counter hook functions so that each layer's count is scaled by its fraction of non-zero weights.
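To make the arithmetic concrete, here is a minimal sketch in plain Python of the scaling such a modified counter applies. The function name `effective_flops` and the layer sizes are hypothetical and chosen only for illustration; the operation count follows the same convention as the dense counter (one operation per weight use).

```python
def effective_flops(dense_flops, num_weights, num_zero_weights):
    """Scale a layer's dense operation count by its fraction of non-zero weights."""
    if num_weights <= 0:
        raise ValueError("num_weights must be positive")
    nonzero_fraction = 1.0 - num_zero_weights / num_weights
    return int(dense_flops * nonzero_fraction)


# Hypothetical fully connected layer: 512 inputs, 256 outputs, batch size 1.
dense = 512 * 256            # 131072 operations in the dense layer
total_weights = 512 * 256    # one weight per input/output pair
zeros = total_weights // 2   # assume half of the weights were pruned to zero

print(effective_flops(dense, total_weights, zeros))  # 65536
```

Note that this is an estimate of the work a sparsity-aware kernel could skip. A standard dense kernel still multiplies by the zeros.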

Below are examples of the modification for fully connected (fc) and convolutional (conv) layers.

import numpy as np
import torch


def linear_flops_counter_hook(module, input, output):
    input = input[0]
    output_last_dim = output.shape[-1]  # pytorch checks dimensions, so here we don't care much
    # MODIFICATION STARTS HERE
    # Scale the dense FLOPs by the fraction of weights that are not (near) zero
    num_zero_weights = (module.weight.data.abs() < 1e-9).sum()
    zero_weights_factor = 1 - torch.true_divide(num_zero_weights, module.weight.data.numel())
    # .item() returns a Python float and, unlike .numpy(), also works for GPU tensors
    module.__flops__ += int(np.prod(input.shape) * output_last_dim * zero_weights_factor.item())
    # MODIFICATION ENDS HERE


def conv_flops_counter_hook(conv_module, input, output):
    # Can have multiple inputs, getting the first one
    input = input[0]

    batch_size = input.shape[0]
    output_dims = list(output.shape[2:])

    kernel_dims = list(conv_module.kernel_size)
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups

    filters_per_channel = out_channels // groups
    conv_per_position_flops = int(np.prod(kernel_dims)) * in_channels * filters_per_channel

    active_elements_count = batch_size * int(np.prod(output_dims))

    # MODIFICATION STARTS HERE
    # Scale the dense conv FLOPs by the fraction of weights that are not (near) zero
    num_zero_weights = (conv_module.weight.data.abs() < 1e-9).sum()
    zero_weights_factor = 1 - torch.true_divide(num_zero_weights, conv_module.weight.data.numel())
    # .item() returns a Python float and, unlike .numpy(), also works for GPU tensors
    overall_conv_flops = conv_per_position_flops * active_elements_count * zero_weights_factor.item()
    # MODIFICATION ENDS HERE

    # Bias additions are not affected by weight pruning
    bias_flops = 0
    if conv_module.bias is not None:
        bias_flops = out_channels * active_elements_count

    overall_flops = overall_conv_flops + bias_flops
    conv_module.__flops__ += int(overall_flops)

Note that I'm using 1e-9 as a threshold for a weight counting as zero.
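The same threshold can be used for the parameter count. The following is a rough sketch, not part of any FLOPs library: `count_nonzero_params` is a hypothetical helper, and it assumes the pruning has already been made permanent (for example with `prune.remove`), so that the zeros are stored in the parameters themselves rather than in a separate mask. Biases are included in both counts.

```python
import torch
from torch import nn


def count_nonzero_params(model, threshold=1e-9):
    """Return (nonzero, total) parameter counts, treating |w| < threshold as zero."""
    nonzero = 0
    total = 0
    for param in model.parameters():
        total += param.numel()
        nonzero += int((param.detach().abs() >= threshold).sum().item())
    return nonzero, total


# Hypothetical toy model: values are set by hand so the result is deterministic.
model = nn.Linear(4, 2)
with torch.no_grad():
    model.weight.fill_(0.5)   # all 8 weights non-zero
    model.weight[0].zero_()   # zero the 4 weights of the first output unit
    model.bias.fill_(1.0)     # both biases non-zero

nonzero, total = count_nonzero_params(model)
print(nonzero, total)  # 6 10
```

For the model in the question, the model would first need to be rebuilt and loaded from 'pruned_model.pth' before passing it to the helper.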

