dca.DataclassArray

class dataclass_array.DataclassArray

Bases: object

Dataclass which behaves like an array.

Usage:

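```python
import dataclass_array as dca
from etils.array_types import f32


class Square(dca.DataclassArray):
  pos: f32['*shape 2']
  scale: f32['*shape']
  name: str


# Create 3 squares batched
p = Square(
    pos=[[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]],
    scale=[1.0, 2.0, 3.0],
    name='my_square',
)
assert p.shape == (3,)
assert p.pos.shape == (3, 2)

# Indexing slices every array field: p[0] is the first square,
# with pos == [0.0, 0.0], scale == 1.0 and the same static name.
assert p[0].shape == ()
assert p[0].pos.shape == (2,)

p = p.reshape((3, 1))  # Reshape the batch shape (inner shapes are kept)
assert p.shape == (3, 1)
assert p.pos.shape == (3, 1, 2)

assert p.name == 'my_square'  # Static fields are unchanged
```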

DataclassArray has 2 types of fields:

  • Array fields: Fields batched like numpy arrays, supporting reshape, slicing, etc. (pos and scale in the above example).

  • Static fields: All other, non-array fields. They are not modified by reshaping, slicing, etc. (name in the above example). Static fields are also ignored by jax.tree_map.
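To make the distinction concrete, here is a toy sketch in plain NumPy (not the library's implementation) of what reshaping the batch shape means for each kind of field. The `ToySquare` class and its `reshape` method are hypothetical: the array fields keep their trailing inner dimensions (`2` for `pos`, none for `scale`), while the static `name` is passed through unchanged.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class ToySquare:
  """Hypothetical stand-in for `Square`, built on plain NumPy arrays."""

  pos: np.ndarray  # Array field: batch shape + inner shape (2,)
  scale: np.ndarray  # Array field: batch shape + inner shape ()
  name: str  # Static field

  @property
  def shape(self) -> tuple:
    # `scale` has no inner dimensions, so its shape is the batch shape.
    return self.scale.shape

  def reshape(self, shape: tuple) -> "ToySquare":
    # Only the batch dimensions change; each inner shape is appended back.
    return ToySquare(
        pos=self.pos.reshape(shape + (2,)),
        scale=self.scale.reshape(shape),
        name=self.name,  # Static field passed through untouched
    )


sq = ToySquare(pos=np.zeros((3, 2)), scale=np.ones(3), name="my_square")
sq2 = sq.reshape((3, 1))
print(sq2.shape, sq2.pos.shape, sq2.name)  # (3, 1) (3, 1, 2) my_square
```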

DataclassArray detects array fields if any of the following holds:

  • The typing annotation is an etils.array_types annotation (in which case shape/dtype are automatically inferred from the typing annotation). Example: x: f32[..., 3]

  • The typing annotation is another dca.DataclassArray (in which case my_dataclass.field.shape == my_dataclass.shape). Example: x: MyDataclass

  • The field is explicitly defined with dca.field, in which case the typing annotation is ignored. Example: x: Any = dca.field(shape=(), dtype=np.int64)

Fields which do not satisfy any of the above conditions are static (including fields annotated with field: np.ndarray or similar).

property shape: Tuple[int, ...]

Returns the batch shape common to all fields.

property size: int

Returns the number of elements, i.e. the product of the batch shape.

property ndim: int

Returns the number of batch dimensions, i.e. len(shape).
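As a sketch of how these three properties relate, assume each array field's shape is the batch shape followed by a fixed inner shape (as in the `Square` example), that `size` is the product of the batch shape and that `ndim` is its length. Under those assumptions the values can be computed with plain NumPy; the `batch_shape` helper and the `inner_shapes` mapping are hypothetical, for illustration only.

```python
import math

import numpy as np


def batch_shape(fields: dict, inner_shapes: dict) -> tuple:
  """Hypothetical helper: strip each field's inner shape and check they agree."""
  shapes = set()
  for name, value in fields.items():
    inner = inner_shapes[name]
    # Batch dimensions are everything before the trailing inner dimensions.
    shapes.add(value.shape[: value.ndim - len(inner)])
  if len(shapes) != 1:
    raise ValueError(f"Inconsistent batch shapes: {shapes}")
  return shapes.pop()


fields = {"pos": np.zeros((4, 3, 2)), "scale": np.zeros((4, 3))}
inner_shapes = {"pos": (2,), "scale": ()}

shape = batch_shape(fields, inner_shapes)
size = math.prod(shape)  # Number of elements in the batch
ndim = len(shape)  # Number of batch dimensions
print(shape, size, ndim)  # (4, 3) 12 2
```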

reshape(shape: Union[Tuple[int, ...], str], **axes_length: int) → dataclass_array.array_dataclass._DcT

Reshape the batch shape according to the pattern.

Supports both tuple and einops mode:

```python
rays.reshape('b h w -> b (h w)')
rays.reshape((128, -1))
```

Parameters:
  • shape – Target shape. Can be a string for einops support.

  • **axes_length – Additional axis lengths required by the einops pattern.

Returns:

The dataclass array with the new shape
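Both modes only rewrite the batch dimensions. Assuming a hypothetical `rays` whose `pos` field has batch shape `(b, h, w)` and inner shape `(3,)`, the effect on that one field can be sketched with plain NumPy (einops itself is not used here). einops' `(h w)` grouping merges axes in row-major order, which matches NumPy's default C-order reshape.

```python
import numpy as np

b, h, w = 2, 4, 5
# Toy `pos` field: batch shape (b, h, w) followed by inner shape (3,).
pos = np.arange(b * h * w * 3).reshape(b, h, w, 3)
inner = pos.shape[3:]

# 'b h w -> b (h w)': merge h and w, leave the inner (3,) untouched.
merged = pos.reshape((b, h * w) + inner)

# Tuple mode, analogous to reshape((128, -1)) but sized for this 40-element
# batch. Because the inner shape is appended explicitly, -1 only absorbs
# batch dimensions.
regrouped = pos.reshape((8, -1) + inner)

print(merged.shape, regrouped.shape)  # (2, 20, 3) (8, 5, 3)
```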

flatten() → dataclass_array.array_dataclass._DcT

Flatten the batch shape.
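Flattening presumably collapses all batch dimensions into a single one while keeping each field's inner shape. A toy NumPy sketch on one field, assuming the batch shape `(3, 1)` and inner shape `(2,)` of the reshaped `Square` example (the variable names are illustrative only):

```python
import math

import numpy as np

# Toy `pos` field: batch shape (3, 1) followed by inner shape (2,).
pos = np.zeros((3, 1, 2))
batch_ndim = 2
batch, inner = pos.shape[:batch_ndim], pos.shape[batch_ndim:]

# Collapse the batch dimensions into one of size prod(batch); keep the inner shape.
flat_pos = pos.reshape((math.prod(batch),) + inner)
print(flat_pos.shape)  # (3, 2)
```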

broadcast_to(shape: Tuple[int, ...]) → dataclass_array.array_dataclass._DcT

Broadcast the batch shape.
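A minimal NumPy sketch of broadcasting the batch shape, assuming each field is broadcast to the target batch shape followed by its own inner shape. The toy `scale` and `pos` arrays are illustrative only. Because NumPy aligns shapes from the right, appending the inner shape explicitly makes the batch dimensions line up with the target.

```python
import numpy as np

# Toy fields sharing the batch shape (3,).
scale = np.array([1.0, 2.0, 3.0])  # Inner shape ()
pos = np.zeros((3, 2))  # Inner shape (2,)
batch_ndim = 1

new_batch = (4, 3)
# Each field is broadcast to the new batch shape followed by its inner shape.
scale_b = np.broadcast_to(scale, new_batch + scale.shape[batch_ndim:])
pos_b = np.broadcast_to(pos, new_batch + pos.shape[batch_ndim:])
print(scale_b.shape, pos_b.shape)  # (4, 3) (4, 3, 2)
```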

map_field(fn: Callable[[etils.enp.array_types.typing.Array], etils.enp.array_types.typing.Array]) → dataclass_array.array_dataclass._DcT

Apply a transformation to every array field.
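A toy sketch of the idea behind map_field, in plain NumPy rather than the library's code: the function is applied to each array field and the static fields are carried over. `ToySquare`, `ARRAY_FIELDS` and `toy_map_field` are hypothetical names used only here.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class ToySquare:
  pos: np.ndarray  # Array field
  scale: np.ndarray  # Array field
  name: str  # Static field


ARRAY_FIELDS = ("pos", "scale")  # Hypothetical: which fields count as arrays


def toy_map_field(obj: ToySquare, fn) -> ToySquare:
  """Hypothetical helper: apply `fn` to every array field, keep static ones."""
  updates = {name: fn(getattr(obj, name)) for name in ARRAY_FIELDS}
  # Returns a new instance; `obj` itself is not modified.
  return dataclasses.replace(obj, **updates)


sq = ToySquare(pos=np.ones((3, 2)), scale=np.ones(3), name="my_square")
doubled = toy_map_field(sq, lambda x: x * 2)
```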

replace(**kwargs: Any) → dataclass_array.array_dataclass._DcT

Alias for dataclasses.replace.
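Since replace is an alias for dataclasses.replace, it follows the standard-library behaviour: it returns a new instance with the given fields swapped and leaves the original untouched. A stdlib-only illustration with a hypothetical frozen dataclass `Label`:

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Label:
  """Hypothetical frozen dataclass used only for this illustration."""

  name: str
  scale: float = 1.0


a = Label(name="my_square")
b = dataclasses.replace(a, scale=2.0)  # New instance; `a` is left unchanged
print(a, b)  # Label(name='my_square', scale=1.0) Label(name='my_square', scale=2.0)
```

On a DataclassArray, p.replace(name='other') would then be equivalent to dataclasses.replace(p, name='other').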

as_np() → dataclass_array.array_dataclass._DcT

Returns the instance with its array fields converted to np.ndarray.

as_jax() → dataclass_array.array_dataclass._DcT

Returns the instance with its array fields converted to jnp.ndarray.

as_tf() → dataclass_array.array_dataclass._DcT

Returns the instance with its array fields converted to tf.Tensor.

as_torch() → dataclass_array.array_dataclass._DcT

Returns the instance with its array fields converted to torch.Tensor.

as_xnp(xnp: Any) → dataclass_array.array_dataclass._DcT

Returns the instance with its array fields converted to xnp.ndarray, where xnp is the given numpy-like module.

property xnp: Any

Returns the numpy-like module backing the instance's arrays (np, jnp, tnp).

to(device, **kwargs) → dataclass_array.array_dataclass._DcT

Move the dataclass array to the device.

cpu(*args, **kwargs) → dataclass_array.array_dataclass._DcT

Move the dataclass array to the CPU device.

cuda(*args, **kwargs) → dataclass_array.array_dataclass._DcT

Move the dataclass array to the CUDA device.

tree_flatten() → tuple[tuple[DcOrArray, ...], _TreeMetadata]

jax.tree_util support: returns the array field values and the static metadata.

classmethod tree_unflatten(metadata: _TreeMetadata, array_field_values: list[DcOrArray]) → _DcT

jax.tree_util support: rebuilds an instance from the metadata and the array field values.
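jax.tree_util treats a registered class as a node that splits into children (visited by tree_map) and auxiliary metadata (carried along unchanged). The following toy sketch of that protocol uses plain NumPy and no JAX: the array fields become the children and the static name becomes metadata, which is why static fields are ignored by tree_map. `ToySquare`, `toy_tree_flatten` and `toy_tree_unflatten` are hypothetical names. With JAX installed, such a pair would be registered through jax.tree_util.register_pytree_node(nodetype, flatten_func, unflatten_func), whose unflatten function takes (metadata, children) in the same order as tree_unflatten above.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class ToySquare:
  pos: np.ndarray  # Array field
  scale: np.ndarray  # Array field
  name: str  # Static field


def toy_tree_flatten(sq: ToySquare):
  children = (sq.pos, sq.scale)  # Array fields: what tree_map visits
  metadata = (sq.name,)  # Static fields: carried along, never mapped over
  return children, metadata


def toy_tree_unflatten(metadata, children) -> ToySquare:
  (name,) = metadata
  pos, scale = children
  return ToySquare(pos=pos, scale=scale, name=name)


sq = ToySquare(pos=np.ones((3, 2)), scale=np.ones(3), name="my_square")
children, metadata = toy_tree_flatten(sq)
# Emulate a tree_map: transform only the children, then rebuild the node.
rebuilt = toy_tree_unflatten(metadata, [c + 1 for c in children])
```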

assert_same_xnp(x: Union[etils.enp.array_types.typing.Array, dataclass_array.array_dataclass.DataclassArray]) → None

Asserts that the given array (or dataclass array) uses the same array type (xnp) as the current object.