mm.math.map_fn

math.map_fn(fn, elements)

Transforms elements by applying fn to each slice obtained by unstacking elements along axis 0 (its first dimension).
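The following is a minimal sketch of these semantics for a single tensor. It uses NumPy arrays as a stand-in for the library's tensors, and map_fn_sketch is a hypothetical name, not part of mm. It assumes that the per-slice results are stacked back along a new axis 0, which is how the returned Tensor is typically formed but is not stated above.

```python
import numpy as np


def map_fn_sketch(fn, elements):
    """Hypothetical reference for the single-tensor case (not mm's code).

    Unstacks `elements` along axis 0, applies `fn` to each slice, and
    stacks the per-slice results along a new axis 0. The stacking step
    is an assumption about how the returned tensor is assembled.
    """
    elements = np.asarray(elements)
    # Iterating over a NumPy array yields its slices along axis 0.
    results = [fn(row) for row in elements]
    return np.stack(results, axis=0)


x = np.array([[1, 2, 3], [4, 5, 6]])
print(map_fn_sketch(np.sum, x))  # [ 6 15]
```

Each row of x is passed to fn separately, so np.sum reduces one row at a time rather than the whole array.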

Parameters:
  • fn (callable) – The function to apply. It takes one argument, which has the same (possibly nested) structure as elements.

  • elements (Tensor) – A tensor or (possibly nested) sequence of tensors, each of which is unstacked along its first dimension. fn is applied to the nested sequence of the resulting slices.
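When elements is a sequence of tensors, fn receives one slice from each tensor, arranged in the same structure. The sketch below illustrates this for a flat tuple only, again using NumPy arrays as a stand-in; map_fn_nested_sketch is a hypothetical helper, and the requirement that every tensor share the same first dimension, as well as stacking the results, are assumptions of this sketch rather than documented behavior of mm.

```python
import numpy as np


def map_fn_nested_sketch(fn, elements):
    """Hypothetical sketch for a flat tuple of tensors (not mm's code).

    Each tensor is unstacked along its first dimension; fn receives a
    tuple holding one slice from each tensor, mirroring the structure
    of `elements`. Deeper nesting is not handled here.
    """
    arrays = tuple(np.asarray(e) for e in elements)
    n = arrays[0].shape[0]
    # Assumption of this sketch: all tensors share the same size on axis 0.
    if any(a.shape[0] != n for a in arrays):
        raise ValueError("all elements must have the same first dimension")
    # Build the i-th tuple of slices and apply fn to it.
    results = [fn(tuple(a[i] for a in arrays)) for i in range(n)]
    return np.stack(results, axis=0)


a = np.array([1, 2, 3])
b = np.array([10, 20, 30])
print(map_fn_nested_sketch(lambda pair: pair[0] + pair[1], (a, b)))  # [11 22 33]
```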

Returns:

The result of applying fn to each slice of elements.

Return type:

Tensor