Debug features

class transformer_engine.debug.features.log_tensor_stats.LogTensorStats

This feature handles the logging of basic tensor statistics.
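
As a rough illustration of the statistics this feature can report, the sketch below computes them for a plain Python list. It is not the library's code: basic_stats is a hypothetical helper, and the stat names follow the stats parameter documented below. Two details are assumptions: std uses the unbiased estimator (the default of torch.std), and amin in dynamic_range is taken to be the smallest non-zero magnitude.

```python
import math


def basic_stats(values):
    """Compute the listed statistics for a flat, non-empty list of floats."""
    n = len(values)
    mean = sum(values) / n
    # Assumption: unbiased (n - 1) estimator, as in torch.std's default.
    var = sum((v - mean) ** 2 for v in values) / (n - 1) if n > 1 else 0.0
    amax = max(abs(v) for v in values)
    # Assumption: amin is the smallest non-zero magnitude, so log2 is defined.
    nonzero = [abs(v) for v in values if v != 0.0]
    amin = min(nonzero) if nonzero else 0.0
    if amin > 0.0:
        dynamic_range = math.log2(amax) - math.log2(amin)
    else:
        dynamic_range = float("nan")  # all-zero input: range is undefined
    return {
        "min": min(values),
        "max": max(values),
        "mean": mean,
        "std": math.sqrt(var),
        "l1_norm": sum(abs(v) for v in values),
        "l2_norm": math.sqrt(sum(v * v for v in values)),
        "cur_amax": amax,
        "dynamic_range": dynamic_range,
    }


stats = basic_stats([0.5, -4.0, 2.0, 0.0])
print(stats["cur_amax"], stats["dynamic_range"])
```

For this input, cur_amax is 4.0 and the smallest non-zero magnitude is 0.5, so dynamic_range is log2(4.0) - log2(0.5) = 3.0.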

For a distributed setting, the auxiliary stats are computed for each node and gathered after the debug_api.step() call. Do not forget to invoke debug_api.step() at every step to log stats!

LogTensorStats supports micro-batching. If multiple forward/backward passes are invoked per debug_api.step(), then stats for all tensors except weights will be accumulated.
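
The accumulate-then-flush behaviour can be pictured with a small, self-contained sketch. It is not the Transformer Engine implementation: StatBuffer, feed, and flush are hypothetical names, and plain Python lists stand in for tensors. Each micro-batch feeds partial results into the buffer, and the combined statistics are emitted once per step, which is the role debug_api.step() plays in the description above.

```python
import math


class StatBuffer:
    """Hypothetical accumulator: collects partial stats until flush()."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.count = 0
        self.total = 0.0
        self.minimum = math.inf
        self.maximum = -math.inf

    def feed(self, values):
        """Fold one micro-batch of values into the running aggregates."""
        for v in values:
            self.count += 1
            self.total += v
            self.minimum = min(self.minimum, v)
            self.maximum = max(self.maximum, v)

    def flush(self):
        """Return stats for everything fed since the last flush, then reset."""
        if self.count == 0:
            return None
        stats = {
            "min": self.minimum,
            "max": self.maximum,
            "mean": self.total / self.count,
        }
        self.reset()
        return stats


buffer = StatBuffer()
micro_batches = [[1.0, 2.0], [3.0, 6.0]]
for batch in micro_batches:      # several forward passes within one step
    buffer.feed(batch)
step_stats = buffer.flush()      # stands in for the work done at debug_api.step()
print(step_stats)
```

Keeping running aggregates (count, sum, min, max) rather than the raw values is one possible design; it means the buffer's memory use does not grow with the number of micro-batches.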

LogTensorStats can induce significant overhead. To mitigate this issue, logging stats with freq > 1 is recommended. If LogTensorStats is not used in a given step, the overhead is smaller. Moreover, if no other feature is used for the layer, the TE layer will run as fast as it would without debug_api initialized.

Parameters:
  • stats (List[str]) –

    list of statistics to log

    • min

    • max

    • mean

    • std

    • l1_norm

    • l2_norm

    • cur_amax – maximum absolute value of the tensor

    • dynamic_range – equal to torch.log2(amax) - torch.log2(amin)

  • tensors/tensors_struct (List[str]) –

    list of tensors to log

    • activation

    • gradient

    • weight

    • output

    • wgrad

    • dgrad

  • freq (Optional[int], default = 1) – logging frequency; stats are logged every freq steps

  • start_step (Optional[int], default = None) – start step of logging stats

  • end_step (Optional[int], default = None) – end step of logging stats

  • start_end_list (Optional[list([int, int])], default = None) – list of non-overlapping (start, end) pairs in increasing order. If not None, start_step and end_step are ignored
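
How freq, start_step, end_step, and start_end_list combine can be sketched as a single predicate. The helper should_log is hypothetical, and its rules are assumptions rather than documented behaviour: window bounds are treated as inclusive, a missing bound is treated as open, and a step is logged when it is a multiple of freq. The values below are the ones used for the activation and gradient tensors in the example that follows.

```python
def should_log(step, freq=1, start_step=None, end_step=None, start_end_list=None):
    """Return True if stats would be logged at `step` under the assumed rules."""
    if start_end_list is not None:
        # start_end_list takes precedence over start_step / end_step.
        in_window = any(start <= step <= end for start, end in start_end_list)
    else:
        after_start = start_step is None or step >= start_step
        before_end = end_step is None or step <= end_step
        in_window = after_start and before_end
    # Assumption: logging happens on steps that are multiples of freq.
    return in_window and step % freq == 0


activation_steps = [
    s for s in range(121) if should_log(s, freq=10, start_step=5, end_step=100)
]
gradient_steps = [
    s for s in range(121)
    if should_log(s, freq=2, start_end_list=[[0, 20], [80, 100]])
]
print(activation_steps)
print(gradient_steps)
```

Under these assumptions the activation stats would be logged at steps 10, 20, ..., 100, and the gradient stats at every even step in the ranges 0 to 20 and 80 to 100.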

Example

example_tensor_stat_collection:
    enabled: True
    layers:
        layer_name_regex_pattern: .*(fc1|self_attention).*
    transformer_engine:
        LogTensorStats:
            enabled: True
            tensors_struct:
                - tensor: activation
                  stats: [mean]
                  freq: 10
                  start_step: 5
                  end_step: 100
                - tensor: gradient
                  stats: [mean, max, min]
                  freq: 2
                  start_end_list: [[0, 20], [80, 100]]
                - tensor: weight
                  stats: [dynamic_range]

class transformer_engine.debug.features.log_fp8_tensor_stats.LogFp8TensorStats

Logs statistics of quantized tensors.

Supports computing statistics for the current recipe, and also shows what would happen if different recipes were used for these tensors in the current iteration. For example, during delayed-scaling training you may wish to track “fp8_current_scaling_underflows%” to measure the accuracy of the current scaling factors; note that this requires an extra cast and therefore adds overhead. Using a logging frequency (freq) greater than 1 is recommended in this case. Computing the stats matching the training recipe does not require an extra cast.

Statistics are identified by the pattern <recipe>_<stat> with an optional _columnwise suffix (e.g. fp8_delayed_scaling_underflows% or mxfp8_scale_inv_min_columnwise). If only <stat> is provided, the current training recipe is used.

Stats for delayed-scaling cannot be collected if delayed-scaling is not the current training recipe.
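
The naming pattern can be made concrete with a small parser. This is a sketch only: parse_stat_name is a hypothetical helper, not part of Transformer Engine, and the recipe and stat names are copied from the Parameters list below. It returns None for the recipe when the name contains only <stat>, which corresponds to "use the current training recipe".

```python
RECIPES = ("fp8_delayed_scaling", "fp8_current_scaling", "mxfp8", "fp8_block_scaling")
STATS = ("underflows%", "overflows%", "scale_inv_min", "scale_inv_max", "mse")
SUFFIX = "_columnwise"


def parse_stat_name(name):
    """Split a stat string into (recipe or None, stat, columnwise flag)."""
    columnwise = name.endswith(SUFFIX)
    core = name[: -len(SUFFIX)] if columnwise else name
    if core in STATS:
        # No recipe prefix: the current training recipe would apply.
        return None, core, columnwise
    for recipe in RECIPES:
        prefix = recipe + "_"
        if core.startswith(prefix) and core[len(prefix):] in STATS:
            return recipe, core[len(prefix):], columnwise
    raise ValueError(f"unrecognized stat name: {name!r}")


print(parse_stat_name("mxfp8_scale_inv_min_columnwise"))
print(parse_stat_name("underflows%"))
```

The sketch does not check whether a given recipe supports a given stat or the _columnwise suffix; those restrictions are described under Parameters.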

In distributed runs each rank first computes its local statistics; the values are gathered the next time debug_api.step() is called. Remember to call debug_api.step() every training step so the logs are flushed.

The feature is micro-batch aware: if several forward/backward passes occur between successive debug_api.step() calls, statistics are accumulated for all tensors except weights.

Collecting FP8 statistics is expensive. Choosing a larger freq reduces the overhead, and if the feature is skipped for a step the additional cost is minimal. When no other debug feature is active, the layer runs at normal Transformer Engine speed.

Parameters:
  • stats (List[str]) –

    Each stat is a string of the form <recipe>_<stat>, with an optional _columnwise suffix (i.e., <recipe>_<stat>_columnwise). If <recipe> is omitted, the current training recipe is used. The _columnwise suffix is supported for mxfp8 and fp8_block_scaling; the stat is then computed on the columnwise (transposed) version of the tensor, which can differ numerically from the rowwise (non-transposed) version. The _columnwise suffix is not supported for fp8_delayed_scaling and fp8_current_scaling.

    recipes:
    • fp8_delayed_scaling,

    • fp8_current_scaling,

    • mxfp8,

    • fp8_block_scaling,

    stats:
    • underflows% - percentage of non-zero elements of tensor clipped to 0 after quantization,

    • overflows% - percentage of elements of tensor that were clipped to the max/min value of the FP8 range - supported only for fp8_delayed_scaling,

    • scale_inv_min - minimum of the inverse of the scaling factors,

    • scale_inv_max - maximum of the inverse of the scaling factors,

    • mse - mean squared error of the quantized tensor and the original tensor = sum((quantized_tensor - original_tensor)**2) / num_elements,

  • tensors/tensors_struct (List[str]) –

    list of tensors to log
    • activation,

    • gradient,

    • weight,

  • freq (Optional[int], default = 1) – logging frequency; stats are logged every freq steps

  • start_step (Optional[int], default = None) – start step of logging stats

  • end_step (Optional[int], default = None) – end step of logging stats

  • start_end_list (Optional[list([int, int])], default = None) – list of non-overlapping (start, end) pairs in increasing order. If not None, start_step and end_step are ignored
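
The underflows% and mse definitions above can be illustrated on plain Python lists. This is a sketch under stated assumptions, not the library's code: toy_quantize is a hypothetical stand-in for an FP8 cast (it simply rounds to multiples of a fixed step), and underflow_percent assumes the percentage is taken over the total number of elements.

```python
def toy_quantize(values, step=0.25):
    """Hypothetical stand-in for an FP8 cast: round to multiples of `step`."""
    return [round(v / step) * step for v in values]


def underflow_percent(original, quantized):
    """Percentage of elements that were non-zero before and zero after."""
    flushed = sum(1 for x, q in zip(original, quantized) if x != 0.0 and q == 0.0)
    # Assumption: the denominator is the total element count.
    return 100.0 * flushed / len(original)


def mse(original, quantized):
    """sum((quantized - original) ** 2) / num_elements, as defined above."""
    return sum((q - x) ** 2 for x, q in zip(original, quantized)) / len(original)


original = [0.0, 0.1, 0.5, -1.0]
quantized = toy_quantize(original)
underflows = underflow_percent(original, quantized)
error = mse(original, quantized)
print(quantized, underflows, error)
```

In this toy case only 0.1 is flushed to zero, so one of four elements underflows; the element that was already zero does not count.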

Example

example_fp8_tensor_stat_collection:
    enabled: True
    layers:
        layer_types: [layernorm_linear]
    transformer_engine:
        LogFp8TensorStats:
            enabled: True
            tensors_struct:
                - tensor: activation
                  stats: [mxfp8_underflows%]
                  freq: 1
                - tensor: gradient
                  stats: [underflows%]
                  freq: 5
                  start_step: 0
                  end_step: 80

class transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM

GEMM operations are executed in higher precision, even when FP8 autocast is enabled.

Parameters:
  • gemms (List[str]) –

    list of gemms to disable

    • fprop

    • dgrad

    • wgrad

Example

example_disable_fp8_gemm:
    enabled: True
    layers:
        layer_types: [fc1]
    transformer_engine:
        DisableFP8GEMM:
            enabled: True
            gemms: [dgrad, wgrad]

class transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer

Disables all FP8 GEMMs in the layer.

Example

example_disable_fp8_layer:
    enabled: True
    layers:
        layer_types: [fc1]
    transformer_engine:
        DisableFP8Layer:
            enabled: True

class transformer_engine.debug.features.per_tensor_scaling.PerTensorScaling

Enables per-tensor current scaling for the selected tensors.

Can be used only within DelayedScaling recipe autocast.
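
To show the difference between the two scaling modes, the sketch below contrasts a scale derived from the tensor's own maximum in the current iteration with one derived from maxima recorded earlier. It is a simplified illustration, not the library's code: current_scale and delayed_scale are hypothetical helpers, the delayed variant assumes the largest value in the history is used, and any margin or other recipe options are left out. The constant 448.0 is the largest finite magnitude of the FP8 E4M3 format.

```python
FP8_E4M3_MAX = 448.0


def current_scale(values, fp8_max=FP8_E4M3_MAX):
    """Scale derived from the tensor's own amax in this iteration."""
    amax = max(abs(v) for v in values)
    return fp8_max / amax if amax > 0.0 else 1.0


def delayed_scale(amax_history, fp8_max=FP8_E4M3_MAX):
    """Scale derived from amax values recorded in earlier iterations.

    Assumption: the largest value in the history is used.
    """
    amax = max(amax_history)
    return fp8_max / amax if amax > 0.0 else 1.0


tensor = [0.5, -2.0, 1.0]           # values seen in the current iteration
amax_history = [8.0, 4.0, 16.0]     # maxima seen in previous iterations

cur = current_scale(tensor)
dly = delayed_scale(amax_history)
print(cur, dly)
```

Here the current tensor is much smaller than the recorded history suggests, so the current-scaling factor (224.0) uses far more of the FP8 range than the delayed one (28.0).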

Parameters:
  • gemms/gemms_struct (List[str]) –

    list of gemms to enable per-tensor current scaling for

    • fprop

    • dgrad

    • wgrad

  • tensors/tensors_struct (List[str]) –

    list of tensors to enable per-tensor current scaling for

    • activation

    • gradient

    • weight

Example

example_per_tensor_scaling:
    enabled: True
    layers:
        layer_types: [transformer_layer.self_attn.layernorm_q]
    transformer_engine:
        PerTensorScaling:
            enabled: True
            gemms: [dgrad]
            tensors: [weight, activation]

class transformer_engine.debug.features.fake_quant.FakeQuant

Disables the FP8 GEMM. Fake-quantizes the chosen tensors to FP8, using a per-tensor scaling factor rather than delayed scaling, and runs the GEMM in high precision.

Fig 1: Comparison of FP8 FPROP GEMM with the same GEMM in BF16 with fake quantization of activation tensor. Green tensors have the same values, but different dtypes.
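
The idea of fake quantization can be sketched in pure Python: values are scaled, rounded to what the FP8 format can represent, and scaled back, so the result is stored in high precision but carries FP8 rounding error. This is a simplified model, not the library's code: fake_quantize and round_to_format are hypothetical helpers, and the rounding ignores the limited exponent range and subnormals. The mantissa widths and maximum values of the two formats are standard (E4M3: 3 bits, 448; E5M2: 2 bits, 57344).

```python
import math

FORMATS = {
    # name: (explicit mantissa bits, largest finite magnitude)
    "FP8E4M3": (3, 448.0),
    "FP8E5M2": (2, 57344.0),
}


def round_to_format(x, mantissa_bits, max_value):
    """Round x to the nearest value with the given mantissa width, then clamp.

    Simplification: exponent range limits and subnormals are not modelled.
    """
    if x == 0.0:
        return 0.0
    _, exponent = math.frexp(x)  # x = m * 2**exponent with 0.5 <= |m| < 1
    quantum = 2.0 ** (exponent - (mantissa_bits + 1))
    rounded = round(x / quantum) * quantum
    return max(-max_value, min(max_value, rounded))


def fake_quantize(values, quant_format="FP8E4M3"):
    """Quantize with a per-tensor scale, then return high-precision values."""
    mantissa_bits, max_value = FORMATS[quant_format]
    amax = max(abs(v) for v in values)
    if amax == 0.0:
        return list(values)
    scale = max_value / amax  # per-tensor scale from the tensor's own amax
    return [
        round_to_format(v * scale, mantissa_bits, max_value) / scale
        for v in values
    ]


activation = [1.0, 0.3, -2.5]
fake = fake_quantize(activation, "FP8E4M3")
print(fake)
```

The returned list could then be fed to an ordinary high-precision matrix multiplication, which mirrors the BF16 path in the figure: the values match what FP8 would hold, but the GEMM itself does not run in FP8.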

Parameters:
  • gemms/gemms_struct (List[str]) –

    list of gemms to fake quantize

    • fprop

    • dgrad

    • wgrad

  • tensors/tensors_struct (List[str]) –

    list of tensors to fake quantize

    • activation

    • gradient

    • weight

    • output

    • wgrad

    • dgrad

  • quant_format (str) –

    specifies the FP8 format to use:

    • FP8E5M2

    • FP8E4M3

Example

example_fake_quant_fp8:
    enabled: True
    layers:
        layer_types: [transformer_layer.layernorm_mlp.fc1]
    transformer_engine:
        FakeQuant:
            enabled: True
            quant_format: FP8E5M2
            gemms_struct:
                - gemm: fprop
                  tensors: [activation, weight]
                - gemm: dgrad
                  tensors: [gradient]