| | |
| | |
| | |
| | |
| | |
| |
|
| | from typing import Callable |
| |
|
| | import torch |
| | from torch import Tensor |
| |
|
| | from flow_matching.utils.manifolds import Manifold |
| |
|
| |
|
def geodesic(
    manifold: Manifold, start_point: Tensor, end_point: Tensor
) -> Callable[[Tensor], Tensor]:
    """Generate parameterized function for geodesic curve.

    The curve is obtained by "shooting" from ``start_point`` along the tangent
    vector ``manifold.logmap(start_point, end_point)``: the point at time
    :math:`t` is ``manifold.expmap(start_point, t * logmap(start_point, end_point))``.

    Args:
        manifold (Manifold): the manifold to compute geodesic on.
        start_point (Tensor): point on the manifold at :math:`t=0`.
        end_point (Tensor): point on the manifold at :math:`t=1`.

    Returns:
        Callable[[Tensor], Tensor]: a function that takes in :math:`t` and outputs the geodesic at time :math:`t`.
    """
    # The log map does not depend on t, so compute it once and reuse it for
    # every evaluation of the returned closure.
    shooting_tangent_vec = manifold.logmap(start_point, end_point)

    def path(t: Tensor) -> Tensor:
        """Evaluate the geodesic at the given times.

        Args:
            t (Tensor): Times at which to compute points of the geodesic,
                given as a 1-D tensor of length ``T``. A 0-dim (scalar) value
                is treated as a length-1 tensor. ``t`` is moved to the device
                of the shooting tangent vector before use.

        Returns:
            Tensor: geodesic path evaluated at time t. The time axis is
            inserted just before the last dimension of ``start_point``; the
            exact output shape depends on how ``manifold.expmap`` broadcasts
            (presumably ``(..., T, D)`` for a ``(..., D)`` start point —
            TODO confirm for non-Euclidean manifolds).
        """
        # Put t on the same device as the tangent vector; a 1-D CPU tensor
        # combined with a CUDA tensor (or vice versa) would otherwise raise.
        # The dtype is deliberately left unchanged so torch's usual type
        # promotion applies exactly as before.
        t = torch.as_tensor(t, device=shooting_tangent_vec.device)
        # The einsum below indexes t with a single axis "i", so promote a
        # scalar to shape (1,). Inputs that are already 1-D are unchanged.
        t = torch.atleast_1d(t)
        # tangent_vecs[..., i, :] = t[i] * shooting_tangent_vec[..., :]
        tangent_vecs = torch.einsum("i,...k->...ik", t, shooting_tangent_vec)
        # Add a time axis to the start point so it broadcasts against
        # tangent_vecs, which has one tangent vector per time.
        points_at_time_t = manifold.expmap(start_point.unsqueeze(-2), tangent_vecs)
        return points_at_time_t

    return path
| |
|