dca.vectorize_method

dataclass_array.vectorize_method(fn: None = None, *, static_args: Set[str] | None = None) -> Callable[[dataclass_array.vectorization._FnT], dataclass_array.vectorization._FnT]
dataclass_array.vectorize_method(fn: dataclass_array.vectorization._FnT, *, static_args: Set[str] | None = None) -> dataclass_array.vectorization._FnT

Vectorize a dca.DataclassArray method.

Allows a dca.DataclassArray method to be implemented as if self.shape == ().

This is similar to jax.vmap but:

  • Only works on dca.DataclassArray methods.

  • Instead of vectorizing a single axis, @dca.vectorize_method vectorizes over all of *self.shape (not just self.shape[0]), as if vmap were applied to self.flatten().

  • Axes of dimension 1 are broadcast.

For example, with __matmul__(self, x: T) -> T:

() @ (*x,) -> (*x,)
(b,) @ (b, *x) -> (b, *x)
(b,) @ (1, *x) -> (b, *x)
(1,) @ (b, *x) -> (b, *x)
(b, h, w) @ (b, h, w, *x) -> (b, h, w, *x)
(1, h, w) @ (b, 1, 1, *x) -> (b, h, w, *x)
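
The batch-shape rule in the table above can be checked with a small pure-Python sketch. The helper `broadcast_batch_shape` is hypothetical and not part of dca; it only models how the batch axes of self and the leading axes of x combine, assuming both have the same number of batch axes, as in every row of the table.

```python
def broadcast_batch_shape(self_shape, x_batch_shape):
    """Hypothetical helper (not part of dca): combine two batch shapes.

    Assumes both shapes have the same rank, as in the table above.
    Axes of size 1 are stretched to match the other operand.
    """
    if len(self_shape) != len(x_batch_shape):
        raise ValueError(
            f"Rank mismatch: {self_shape} vs {x_batch_shape}"
        )
    out = []
    for a, b in zip(self_shape, x_batch_shape):
        if a != b and 1 not in (a, b):
            raise ValueError(
                f"Incompatible batch shapes: {self_shape} vs {x_batch_shape}"
            )
        # If a is 1, take b (which may also be 1); otherwise keep a.
        out.append(b if a == 1 else a)
    return tuple(out)


b, h, w = 4, 5, 6
scalar_case = broadcast_batch_shape((), ())              # () @ (*x,)
stretched_x = broadcast_batch_shape((b,), (1,))          # (b,) @ (1, *x)
mixed = broadcast_batch_shape((1, h, w), (b, 1, 1))      # (1, h, w) @ (b, 1, 1, *x)
```

Here `mixed` is `(4, 5, 6)`, matching the last row of the table; the trailing `*x` axes are not batch axes and are left out of this sketch.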

Example:

class Point3d(dca.DataclassArray):
  p: f32['*shape 3']

  @dca.vectorize_method
  def first_value(self):
    return self.p[0]

point = Point3d(p=[  # 4 points batched together
    [10, 11, 12],
    [20, 21, 22],
    [30, 31, 32],
    [40, 41, 42],
])
point.first_value() == [10, 20, 30, 40]  # First value of each point
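
To make the "vmap applied to self.flatten()" analogy concrete, here is a rough NumPy-only sketch of the flatten, apply, reshape pattern using the same data as the Point3d example. The function `vectorize_over_batch` and its `inner_ndim` parameter are hypothetical stand-ins for illustration; this is not how dca implements the decorator, which may well use a real vmap rather than a Python loop.

```python
import numpy as np


def first_value_single(p):
    """Method body written for a single point: p has shape (3,)."""
    return p[0]


def vectorize_over_batch(fn, p, inner_ndim=1):
    """Hypothetical stand-in for the decorator (not the dca implementation).

    Flattens all batch axes into one, applies fn to each element,
    then restores the original batch shape.
    """
    split = p.ndim - inner_ndim
    batch_shape = p.shape[:split]
    flat = p.reshape((-1,) + p.shape[split:])   # (n, *inner)
    out = np.stack([fn(x) for x in flat])       # (n, *result)
    return out.reshape(batch_shape + out.shape[1:])


p = np.array(
    [[10, 11, 12], [20, 21, 22], [30, 31, 32], [40, 41, 42]],
    dtype=np.float32,
)

# Batch shape (4,): one result per point.
result = vectorize_over_batch(first_value_single, p)

# Batch shape (2, 2): every batch axis is vectorized, not only the first.
result_2d = vectorize_over_batch(first_value_single, p.reshape(2, 2, 3))
```

`result` holds the values 10, 20, 30, 40 with shape (4,), and `result_2d` holds the same values with shape (2, 2), which is the behavior the "vectorize over *self.shape" bullet describes.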
Parameters:
  • fn – DataclassArray method to decorate

  • static_args – If given, should be a set of the static argument names

Returns:

Decorated function with vectorization applied to self.

Return type:

fn