dca.DataclassArray
- class dataclass_array.DataclassArray
Bases: object
Dataclass which behaves like an array.
Usage:
```python
from dataclass_array import DataclassArray
from etils.array_types import f32


class Square(DataclassArray):
  pos: f32['*shape 2']
  scale: f32['*shape']
  name: str

# Create 3 squares batched (x0, y0, ..., scale2 stand for float values)
p = Square(
    pos=[[x0, y0], [x1, y1], [x2, y2]],
    scale=[scale0, scale1, scale2],
    name='my_square',
)
# The following hold:
p.shape == (3,)
p.pos.shape == (3, 2)
p[0] == Square(pos=[x0, y0], scale=scale0, name='my_square')

p = p.reshape((3, 1))  # Reshape the batch shape
p.shape == (3, 1)
p.pos.shape == (3, 1, 2)
p.name == 'my_square'
```
DataclassArray has 2 types of fields:
- Array fields: fields batched like numpy arrays, supporting reshape, slicing, … (pos and scale in the above example).
- Static fields: all other, non-numpy fields. They are not modified by reshaping, … (name in the above example). Static fields are also ignored by jax.tree_map.
DataclassArray detects array fields if any of the following holds:
- The typing annotation is an etils.array_types annotation (in which case shape/dtype are automatically inferred from the typing annotation). Example: x: f32[..., 3]
- The typing annotation is another dca.DataclassArray (in which case my_dataclass.field.shape == my_dataclass.shape). Example: x: MyDataclass
- The field is explicitly defined with dca.array_field, in which case the typing annotation is ignored. Example: x: Any = dca.field(shape=(), dtype=np.int64)
Fields which do not satisfy any of the above conditions are static (including fields annotated with field: np.ndarray or similar).
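The rules above mean that every array field has a shared batch-shape prefix followed by an inner shape taken from its annotation: (2,) for pos: f32['*shape 2'] and () for scale: f32['*shape'] in the Square example. The stdlib sketch below shows only that shape bookkeeping; infer_batch_shape is a hypothetical helper, not part of the dca API, and static fields such as name never enter the computation.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def infer_batch_shape(field_shapes: Dict[str, Shape], inner_shapes: Dict[str, Shape]) -> Shape:
    """Hypothetical helper: strip each field's inner shape and check the prefixes agree.

    field_shapes: full shape of every array field (static fields are not passed).
    inner_shapes: inner shape declared by each field's annotation.
    """
    batch_shape = None
    for name, shape in field_shapes.items():
        inner = inner_shapes[name]
        split = len(shape) - len(inner)
        # The trailing dims must match the inner shape declared in the annotation.
        if split < 0 or shape[split:] != inner:
            raise ValueError(f"{name}: shape {shape} does not end with {inner}")
        prefix = shape[:split]
        # Every array field must agree on the same batch shape.
        if batch_shape is None:
            batch_shape = prefix
        elif prefix != batch_shape:
            raise ValueError(f"{name}: batch shape {prefix} != {batch_shape}")
    return batch_shape if batch_shape is not None else ()
```

With pos of shape (3, 2) and scale of shape (3,), the sketch returns (3,), matching p.shape in the usage example.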
- property shape: Tuple[int, ...]
Returns the batch shape common to all array fields.
- property size: int
Returns the number of elements in the batch shape (the product of shape).
- property ndim: int
Returns the number of batch dimensions (len(shape)).
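Because size and ndim describe the batch shape only, both follow from shape alone. A minimal sketch, assuming they mirror NumPy's definitions applied to the batch shape (batch_size and batch_ndim are illustrative names, not library functions):

```python
import math
from typing import Tuple


def batch_size(shape: Tuple[int, ...]) -> int:
    # Number of elements in the batch: product of the batch dims (1 for shape ()).
    return math.prod(shape)


def batch_ndim(shape: Tuple[int, ...]) -> int:
    # Number of batch dimensions; the inner dims of the fields are not counted.
    return len(shape)
```

For the reshaped Square from the usage example (shape (3, 1)), size would be 3 and ndim 2, even though p.pos.ndim is 3.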
- reshape(shape: Union[Tuple[int, ...], str], **axes_length: int) → dataclass_array.array_dataclass._DcT
Reshape the batch shape according to the pattern.
Supports both tuple and einops modes:
```python
rays.reshape('b h w -> b (h w)')  # einops pattern
rays.reshape((128, -1))  # tuple, as in np.reshape
```
- Parameters:
shape – Target shape. Can be a string for einops support.
**axes_length – Additional axis lengths for the einops pattern (e.g. h=32 when splitting an axis).
- Returns:
The dataclass array with the new shape.
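In tuple mode only the batch dimensions are reshaped; each array field keeps its inner dimensions at the end. The stdlib sketch below reproduces that shape arithmetic, including NumPy-style -1 inference. resolve_shape and reshape_field_shapes are hypothetical helpers, and the einops path is not modelled.

```python
import math
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def resolve_shape(old: Shape, new: Shape) -> Shape:
    """Hypothetical helper: fill in a single -1 so `new` has as many elements as `old`."""
    total = math.prod(old)
    if new.count(-1) > 1:
        raise ValueError("only one dimension can be -1")
    if -1 in new:
        known = math.prod(d for d in new if d != -1)
        if known == 0 or total % known:
            raise ValueError(f"cannot reshape batch shape {old} into {new}")
        new = tuple(total // known if d == -1 else d for d in new)
    if math.prod(new) != total:
        raise ValueError(f"cannot reshape batch shape {old} into {new}")
    return new


def reshape_field_shapes(
    field_shapes: Dict[str, Shape], batch_shape: Shape, new: Shape
) -> Dict[str, Shape]:
    """Hypothetical helper: new batch dims followed by each field's unchanged inner dims."""
    new = resolve_shape(batch_shape, tuple(new))
    n = len(batch_shape)
    return {name: new + shape[n:] for name, shape in field_shapes.items()}
```

For rays with batch shape (128, 32, 32) and pos of shape (128, 32, 32, 3), rays.reshape((128, -1)) would give batch shape (128, 1024) and pos of shape (128, 1024, 3), the same result as the einops pattern 'b h w -> b (h w)'.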
- broadcast_to(shape: Tuple[int, ...]) → dataclass_array.array_dataclass._DcT
Broadcast the batch shape to shape; the inner shapes of the array fields are kept.
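Since only the batch shape is broadcast, each array field's target is the new batch shape followed by its own inner shape. A stdlib sketch of that rule, assuming NumPy's right-aligned broadcasting applies to the batch dims (can_broadcast and broadcast_field_shapes are hypothetical helpers):

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def can_broadcast(src: Shape, dst: Shape) -> bool:
    """NumPy rule: align shapes on the right; each src dim must be 1 or equal the dst dim."""
    if len(src) > len(dst):
        return False
    return all(s in (1, d) for s, d in zip(reversed(src), reversed(dst)))


def broadcast_field_shapes(
    field_shapes: Dict[str, Shape], batch_shape: Shape, target: Shape
) -> Dict[str, Shape]:
    """Hypothetical helper: target batch dims + each field's inner dims."""
    if not can_broadcast(batch_shape, target):
        raise ValueError(f"cannot broadcast batch shape {batch_shape} to {target}")
    n = len(batch_shape)
    return {name: tuple(target) + shape[n:] for name, shape in field_shapes.items()}
```

Appending the inner shape is what keeps the operation on the batch dims: calling np.broadcast_to on the raw pos array of shape (3, 2) with (4, 3) would fail, because NumPy aligns from the right and would compare the inner dim 2 against 3.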
- map_field(fn: Callable[[etils.enp.array_types.typing.Array], etils.enp.array_types.typing.Array]) → dataclass_array.array_dataclass._DcT
Apply fn to all arrays of the array fields and return a new dataclass array; static fields are left unchanged.
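The idea behind map_field can be sketched with a frozen stdlib dataclass: fn is applied to the array fields only, and a new instance is returned. ToySquare, ARRAY_FIELDS and map_array_fields are hypothetical stand-ins that hold plain lists instead of arrays, and the sketch does not cover nested dataclass-array fields.

```python
import dataclasses
from typing import Any, Callable

# Hypothetical: names of the fields treated as array fields.
ARRAY_FIELDS = ("pos", "scale")


@dataclasses.dataclass(frozen=True)
class ToySquare:
    pos: Any    # array field (a list here, for illustration)
    scale: Any  # array field
    name: str   # static field


def map_array_fields(obj: ToySquare, fn: Callable[[Any], Any]) -> ToySquare:
    # dataclasses.replace builds a new frozen instance; static fields are copied as-is.
    return dataclasses.replace(obj, **{f: fn(getattr(obj, f)) for f in ARRAY_FIELDS})
```

Returning a new instance instead of mutating matches the frozen-dataclass style, so the original object stays valid after the call.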
- as_np() → dataclass_array.array_dataclass._DcT
Returns the instance with all array fields converted to np.ndarray.
- as_jax() → dataclass_array.array_dataclass._DcT
Returns the instance with all array fields converted to jnp.ndarray.
- as_torch() → dataclass_array.array_dataclass._DcT
Returns the instance with all array fields converted to torch.Tensor.
- as_xnp(xnp: Any) → dataclass_array.array_dataclass._DcT
Returns the instance with all array fields converted to xnp.ndarray.
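One way to structure these conversions is a single generic as_xnp that calls xnp.asarray on each array field, with as_np, as_jax and as_torch as thin wrappers. This is an assumption suggested by the method names, not confirmed by the source. In the sketch, ToySquare and the fake_xnp namespace are hypothetical; fake_xnp's asarray simply turns lists into tuples so the code runs with the standard library only.

```python
import dataclasses
import types
from typing import Any

# Hypothetical: names of the fields treated as array fields.
ARRAY_FIELDS = ("pos", "scale")


@dataclasses.dataclass(frozen=True)
class ToySquare:
    pos: Any
    scale: Any
    name: str  # static field, never converted


def as_xnp(obj: ToySquare, xnp: Any) -> ToySquare:
    """Return a copy whose array fields were converted with `xnp.asarray`."""
    return dataclasses.replace(obj, **{f: xnp.asarray(getattr(obj, f)) for f in ARRAY_FIELDS})


def as_np(obj: ToySquare) -> ToySquare:
    import numpy as np  # imported lazily so the sketch itself does not require NumPy
    return as_xnp(obj, np)


# Stand-in "array module" for demonstration: asarray converts a list to a tuple.
fake_xnp = types.SimpleNamespace(asarray=tuple)
```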
- property xnp: Any
Returns the numpy module of the class (np, jnp, tnp).
- to(device, **kwargs) → dataclass_array.array_dataclass._DcT
Move the dataclass array to the device.
- cpu(*args, **kwargs) → dataclass_array.array_dataclass._DcT
Move the dataclass array to the CPU device.
- cuda(*args, **kwargs) → dataclass_array.array_dataclass._DcT
Move the dataclass array to the CUDA device.
- classmethod tree_unflatten(metadata: _TreeMetadata, array_field_values: list[DcOrArray]) → _DcT
jax.tree_util support: rebuilds the instance from the tree metadata and the array field values.
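tree_unflatten is the second half of JAX's pytree protocol: jax.tree_util.register_pytree_node(cls, flatten_fn, unflatten_fn) expects flatten_fn to return (children, aux_data) and later calls unflatten_fn(aux_data, children). The stdlib sketch below mimics that contract without JAX: array fields become the children and static fields travel in the metadata, which is why a tree map leaves static fields such as name untouched. ToySquare, tree_flatten, tree_unflatten and toy_tree_map are hypothetical; the real _TreeMetadata layout is not described here.

```python
import dataclasses
from typing import Any, Callable, List, Tuple

# Hypothetical: names of the fields treated as array fields.
ARRAY_FIELDS = ("pos", "scale")


@dataclasses.dataclass(frozen=True)
class ToySquare:
    pos: Any
    scale: Any
    name: str  # static field


def tree_flatten(obj: ToySquare) -> Tuple[List[Any], Tuple[type, dict]]:
    # Children: the array field values. Metadata: the class plus the static fields.
    children = [getattr(obj, f) for f in ARRAY_FIELDS]
    static = {
        f.name: getattr(obj, f.name)
        for f in dataclasses.fields(obj)
        if f.name not in ARRAY_FIELDS
    }
    return children, (type(obj), static)


def tree_unflatten(metadata: Tuple[type, dict], children: List[Any]) -> ToySquare:
    # Same argument order as JAX's unflatten function: metadata first, then children.
    cls, static = metadata
    return cls(**dict(zip(ARRAY_FIELDS, children)), **static)


def toy_tree_map(fn: Callable[[Any], Any], obj: ToySquare) -> ToySquare:
    # Unlike real JAX, each field value is treated as a single leaf here.
    children, metadata = tree_flatten(obj)
    return tree_unflatten(metadata, [fn(c) for c in children])
```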
- assert_same_xnp(x: Union[etils.enp.array_types.typing.Array, dataclass_array.array_dataclass.DataclassArray]) → None
Assert that x (an array or a dataclass array) uses the same numpy module (xnp) as the current object.