dca.vectorize_method
- dataclass_array.vectorize_method(fn: None = None, *, static_args: Set[str] | None = None) -> Callable[[dataclass_array.vectorization._FnT], dataclass_array.vectorization._FnT]
- dataclass_array.vectorize_method(fn: dataclass_array.vectorization._FnT, *, static_args: Set[str] | None = None) -> dataclass_array.vectorization._FnT
Vectorize a dca.DataclassArray method.
Allows a method of a dca.DataclassArray to be implemented as if self.shape == ().
This is similar to jax.vmap, but:
- It only works on dca.DataclassArray methods.
- Instead of vectorizing a single axis, @dca.vectorize_method vectorizes over *self.shape (not just self.shape[0]). This is as if vmap were applied to self.flatten().
- Axes with dimension 1 are broadcast.
For example, with __matmul__(self, x: T) -> T:
    () @ (*x,) -> (*x,)
    (b,) @ (b, *x) -> (b, *x)
    (b,) @ (1, *x) -> (b, *x)
    (1,) @ (b, *x) -> (b, *x)
    (b, h, w) @ (b, h, w, *x) -> (b, h, w, *x)
    (1, h, w) @ (b, 1, 1, *x) -> (b, h, w, *x)
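The shape rules above can be reproduced with ordinary NumPy broadcasting. The following is only a sketch of the rule, not the library's implementation: `vectorized_output_shape` is a hypothetical helper, and it assumes that `x` has at least as many leading axes as `self.shape` and that the trailing `*x` axes pass through unchanged.

```python
import numpy as np


def vectorized_output_shape(self_shape, x_shape):
    """Hypothetical helper: predict the output shape for the rules above."""
    n = len(self_shape)
    # The leading axes of `x` line up with the batch axes of `self`.
    x_batch, x_inner = tuple(x_shape[:n]), tuple(x_shape[n:])
    # Axes of size 1 are stretched, as in standard NumPy broadcasting.
    batch = np.broadcast_shapes(tuple(self_shape), x_batch)
    # The trailing (non-batch) axes of `x` are kept as they are.
    return batch + x_inner


b, h, w = 4, 5, 6
print(vectorized_output_shape((b,), (1, 3)))             # (4, 3)
print(vectorized_output_shape((1, h, w), (b, 1, 1, 3)))  # (4, 5, 6, 3)
```

Here the concrete inner shape `(3,)` stands in for `*x`; any other trailing shape would be carried through the same way.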
Example:
class Point3d(dca.DataclassArray):
  p: f32['*shape 3']

  @dca.vectorize_method
  def first_value(self):
    return self.p[0]

point = Point3d(p=[  # 4 points batched together
    [10, 11, 12],
    [20, 21, 22],
    [30, 31, 32],
    [40, 41, 42],
])
point.first_value() == [10, 20, 30, 40]  # First value of each point
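To make the "vmap applied to self.flatten()" description concrete, the idea can be imitated in plain NumPy, independently of dataclass_array. This is a rough sketch under stated assumptions, not how the decorator is implemented: `vectorize_over_batch` and its `inner_ndim` argument are hypothetical names, the per-element function is called in a Python loop rather than traced, and an empty batch is not handled.

```python
import numpy as np


def vectorize_over_batch(fn, inner_ndim):
    """Hypothetical helper: apply `fn`, written for one element, over all batch axes."""

    def wrapped(arr):
        arr = np.asarray(arr)
        split = arr.ndim - inner_ndim
        batch_shape, inner_shape = arr.shape[:split], arr.shape[split:]
        # Collapse every batch axis into a single leading axis (the "flatten" step).
        flat = arr.reshape((-1,) + inner_shape)
        # Call the single-element function once per flattened element.
        out = np.stack([np.asarray(fn(elem)) for elem in flat])
        # Restore the original batch axes in front of the per-element output shape.
        return out.reshape(batch_shape + out.shape[1:])

    return wrapped


# Written as if it received a single point of shape (3,).
first_value = vectorize_over_batch(lambda p: p[0], inner_ndim=1)

points = np.array([
    [10, 11, 12],
    [20, 21, 22],
    [30, 31, 32],
    [40, 41, 42],
])
print(first_value(points))  # [10 20 30 40]
```

The same wrapper works unchanged for a batch of shape (2, 4, 3) or for a single point of shape (3,), which is the property the decorator provides for DataclassArray methods.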
- Parameters:
fn – DataclassArray method to decorate
static_args – If given, should be a set of the static argument names
- Returns:
Decorated function with vectorization applied to self.
- Return type:
fn