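"""Complex-tensor helpers for the beamformer module.

Each helper accepts either a torch_complex ``ComplexTensor`` or a native
PyTorch complex tensor and dispatches to the matching implementation.
"""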
| | from typing import Sequence, Tuple, Union |
| |
|
| | import torch |
| | from packaging.version import parse as V |
| | from torch_complex import functional as FC |
| | from torch_complex.tensor import ComplexTensor |
| |
|
# float64 machine epsilon; complex_norm adds it under the square root as a
# small numerical floor.
EPS = torch.finfo(torch.float64).eps

# PyTorch version gates consulted by the dispatch helpers in this module.
is_torch_1_8_plus, is_torch_1_9_plus = (
    V(torch.__version__) >= V(release) for release in ("1.8.0", "1.9.0")
)
| |
|
| |
|
def new_complex_like(
    ref: Union[torch.Tensor, ComplexTensor],
    real_imag: Tuple[torch.Tensor, torch.Tensor],
):
    """Build a complex tensor of the same kind as ``ref``.

    Args:
        ref: Reference tensor. Only its type is used: a ``ComplexTensor``
            yields a ``ComplexTensor``, a native PyTorch complex tensor
            yields a native complex tensor.
        real_imag: ``(real, imag)`` pair used to build the new value.

    Returns:
        ``ComplexTensor(real, imag)`` or ``torch.complex(real, imag)``.

    Raises:
        ValueError: If ``ref`` is neither a ``ComplexTensor`` nor a native
            complex tensor (on PyTorch < 1.9 native complex tensors are not
            recognised at all).
    """
    if isinstance(ref, ComplexTensor):
        return ComplexTensor(*real_imag)
    if is_torch_complex_tensor(ref):
        return torch.complex(*real_imag)
    if not is_torch_1_9_plus:
        # Native complex tensors are only recognised on PyTorch 1.9+.
        raise ValueError(
            "Please update your PyTorch version to 1.9+ for complex support."
        )
    # On 1.9+ the reference itself is the problem (e.g. a real tensor), so
    # telling the user to upgrade PyTorch would be misleading.
    raise ValueError(
        "Expected `ref` to be a ComplexTensor or a complex torch.Tensor, "
        f"but got {type(ref).__name__} with dtype {getattr(ref, 'dtype', None)}."
    )
| |
|
| |
|
def is_torch_complex_tensor(c):
    """Tell whether ``c`` is a native PyTorch complex tensor.

    ``ComplexTensor`` instances never count. On PyTorch older than 1.9 this
    always answers False.
    """
    if isinstance(c, ComplexTensor) or not is_torch_1_9_plus:
        return False
    return torch.is_complex(c)
| |
|
| |
|
def is_complex(c):
    """Tell whether ``c`` is any supported complex tensor.

    Supported kinds are a ``ComplexTensor`` or a native complex tensor as
    recognised by ``is_torch_complex_tensor``.
    """
    if isinstance(c, ComplexTensor):
        return True
    return is_torch_complex_tensor(c)
| |
|
| |
|
def to_double(c):
    """Cast ``c`` to double precision while preserving complexness.

    Native PyTorch complex tensors are cast to ``torch.complex128``. Anything
    else (``ComplexTensor`` or real tensors) goes through ``.double()``.
    """
    # Reuse the shared predicate instead of repeating its condition inline.
    if is_torch_complex_tensor(c):
        # ``.double()`` means ``.to(torch.float64)``, a real dtype, so native
        # complex tensors need an explicit complex target dtype.
        return c.to(dtype=torch.complex128)
    return c.double()
| |
|
| |
|
def to_float(c):
    """Cast ``c`` to single precision while preserving complexness.

    Native PyTorch complex tensors are cast to ``torch.complex64``. Anything
    else (``ComplexTensor`` or real tensors) goes through ``.float()``.
    """
    # Reuse the shared predicate instead of repeating its condition inline.
    if is_torch_complex_tensor(c):
        # ``.float()`` means ``.to(torch.float32)``, a real dtype, so native
        # complex tensors need an explicit complex target dtype.
        return c.to(dtype=torch.complex64)
    return c.float()
| |
|
| |
|
def cat(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
    """Concatenate a list/tuple of tensors.

    Dispatches on the first element: ``FC.cat`` for a ``ComplexTensor``,
    ``torch.cat`` otherwise. Extra positional/keyword arguments (e.g.
    ``dim``) are forwarded unchanged.

    Raises:
        TypeError: If ``seq`` is not a list or tuple.
    """
    if not isinstance(seq, (list, tuple)):
        raise TypeError(
            "cat(): argument 'tensors' (position 1) must be tuple of Tensors, "
            "not Tensor"
        )
    # Guard the ``seq[0]`` lookup. An empty sequence used to die with an
    # IndexError here; now it reaches torch.cat, which reports it itself.
    if seq and isinstance(seq[0], ComplexTensor):
        return FC.cat(seq, *args, **kwargs)
    return torch.cat(seq, *args, **kwargs)
| |
|
| |
|
def complex_norm(
    c: Union[torch.Tensor, ComplexTensor], dim=-1, keepdim=False
) -> torch.Tensor:
    """Return the 2-norm (magnitude) of a complex tensor.

    Args:
        c: ``ComplexTensor`` or native PyTorch complex tensor.
        dim: Dimension to reduce. ``None`` reduces over every element.
        keepdim: Keep the reduced dimension. On the ``ComplexTensor`` path it
            is ignored when ``dim`` is None.

    Returns:
        A real-valued tensor with the norm.

    Raises:
        TypeError: If ``c`` is not a complex tensor.
    """
    if not is_complex(c):
        raise TypeError("Input is not a complex tensor.")
    if is_torch_complex_tensor(c):
        return torch.norm(c, dim=dim, keepdim=keepdim)

    # ComplexTensor path: sqrt(sum(re^2 + im^2) + EPS). The native path above
    # does not add EPS. The offset presumably keeps sqrt's gradient finite
    # at zero.
    power = c.real**2 + c.imag**2
    if dim is None:
        total = power.sum()
    else:
        total = power.sum(dim=dim, keepdim=keepdim)
    return torch.sqrt(total + EPS)
| |
|
| |
|
def einsum(equation, *operands):
    """Evaluate an Einstein summation over ComplexTensor / torch tensors.

    NOTE: Do not mix ``ComplexTensor`` and native torch complex tensors in
    the same call.

    Operands may be given as varargs (``einsum(eq, a, b)``) or as a single
    list/tuple (``einsum(eq, [a, b])``). Both forms are dispatched the same
    way:

    * One operand: ``FC.einsum`` for a ``ComplexTensor``, else
      ``torch.einsum``.
    * Two operands: a ``ComplexTensor`` on either side selects
      ``FC.einsum``. On PyTorch 1.9+, a real tensor combined with a native
      complex tensor is handled by contracting the real and imaginary parts
      separately.
    * More than two operands: all operands must share one dtype.

    Raises:
        ValueError: If no operand is given, or if more than two operands
            with differing dtypes are given.
    """
    # Unwrap the list/tuple form so it goes through the same dispatch as
    # varargs. Otherwise a 2-element list skips the mixed real/complex path.
    if len(operands) == 1 and isinstance(operands[0], (tuple, list)):
        operands = tuple(operands[0])
    if not operands:
        # Without this check, ``operands[0]`` below raises IndexError.
        raise ValueError("einsum(): at least one operand is required.")

    if len(operands) == 1:
        (op,) = operands
        complex_module = FC if isinstance(op, ComplexTensor) else torch
        return complex_module.einsum(equation, op)

    if len(operands) > 2:
        op0 = operands[0]
        if not all(op.dtype == op0.dtype for op in operands[1:]):
            raise ValueError("0 or More than 2 operands are not supported.")
        _einsum = FC.einsum if isinstance(op0, ComplexTensor) else torch.einsum
        return _einsum(equation, *operands)

    a, b = operands
    if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
        return FC.einsum(equation, a, b)
    if is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
        # Real x complex is split into two real-valued contractions and then
        # recombined. Presumably this is because torch.einsum rejects mixed
        # real/complex operands.
        if not torch.is_complex(a):
            o_real = torch.einsum(equation, a, b.real)
            o_imag = torch.einsum(equation, a, b.imag)
            return torch.complex(o_real, o_imag)
        if not torch.is_complex(b):
            o_real = torch.einsum(equation, a.real, b)
            o_imag = torch.einsum(equation, a.imag, b)
            return torch.complex(o_real, o_imag)
    return torch.einsum(equation, a, b)
| |
|
| |
|
def inverse(
    c: Union[torch.Tensor, ComplexTensor]
) -> Union[torch.Tensor, ComplexTensor]:
    """Return the matrix inverse of ``c``.

    ``ComplexTensor`` inputs use ``ComplexTensor.inverse2()``. Every other
    input uses its own ``.inverse()`` method.
    """
    invert = c.inverse2 if isinstance(c, ComplexTensor) else c.inverse
    return invert()
| |
|
| |
|
def matmul(
    a: Union[torch.Tensor, ComplexTensor], b: Union[torch.Tensor, ComplexTensor]
) -> Union[torch.Tensor, ComplexTensor]:
    """Matrix product ``a @ b`` for ComplexTensor / torch tensors.

    NOTE: Do not mix ``ComplexTensor`` and native torch complex tensors.

    A ``ComplexTensor`` on either side selects ``FC.matmul``. On PyTorch 1.9+,
    a real tensor multiplied with a native complex tensor is computed as two
    real products (real part, imaginary part) recombined with
    ``torch.complex``. All other cases use ``torch.matmul`` directly.
    """
    if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
        return FC.matmul(a, b)
    if not is_torch_1_9_plus:
        return torch.matmul(a, b)

    a_is_complex = torch.is_complex(a)
    b_is_complex = torch.is_complex(b)
    if a_is_complex == b_is_complex:
        # Both real or both complex: no splitting needed.
        return torch.matmul(a, b)
    if b_is_complex:
        # real @ complex
        return torch.complex(torch.matmul(a, b.real), torch.matmul(a, b.imag))
    # complex @ real
    return torch.complex(torch.matmul(a.real, b), torch.matmul(a.imag, b))
| |
|
| |
|
def trace(a: Union[torch.Tensor, ComplexTensor]):
    """Return the matrix trace of ``a`` via ``FC.trace``.

    Both ``ComplexTensor`` and ordinary tensors are delegated to
    ``FC.trace``. ``torch.trace`` is documented for 2-D input only, which is
    presumably why batched inputs are not sent there. TODO confirm the exact
    batching semantics of ``FC.trace``.
    """
    result = FC.trace(a)
    return result
| |
|
| |
|
def reverse(a: Union[torch.Tensor, ComplexTensor], dim=0):
    """Flip ``a`` along dimension ``dim``.

    Uses ``torch.flip`` for ordinary tensors and ``FC.reverse`` for
    ``ComplexTensor``.
    """
    if not isinstance(a, ComplexTensor):
        return torch.flip(a, dims=(dim,))
    return FC.reverse(a, dim=dim)
| |
|
| |
|
def solve(b: Union[torch.Tensor, ComplexTensor], a: Union[torch.Tensor, ComplexTensor]):
    """Solve the linear equation ax = b.

    NOTE: the argument order is ``(b, a)``. Do not mix ``ComplexTensor``
    and native torch complex tensors.

    When both operands are the same kind (both ``ComplexTensor``, both
    native complex, or both real), a dedicated solver is used. When only one
    operand is complex, the result is computed as ``inverse(a) @ b``.
    """
    a_is_ct = isinstance(a, ComplexTensor)
    b_is_ct = isinstance(b, ComplexTensor)
    if a_is_ct and b_is_ct:
        return FC.solve(b, a, return_LU=False)
    if a_is_ct or b_is_ct:
        return matmul(inverse(a), b)

    if is_torch_1_9_plus:
        a_is_complex = torch.is_complex(a)
        b_is_complex = torch.is_complex(b)
        if a_is_complex and b_is_complex:
            return torch.linalg.solve(a, b)
        if a_is_complex or b_is_complex:
            # Mixed real/complex inputs fall back to an explicit inverse.
            return matmul(inverse(a), b)

    # Real inputs (or PyTorch < 1.9): pick the solver available in this
    # version. Note the legacy torch.solve takes (b, a).
    if is_torch_1_8_plus:
        return torch.linalg.solve(a, b)
    return torch.solve(b, a)[0]
| |
|
| |
|
def stack(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
    """Stack a list/tuple of tensors along a new dimension.

    Dispatches on the first element: ``FC.stack`` for a ``ComplexTensor``,
    ``torch.stack`` otherwise. Extra positional/keyword arguments (e.g.
    ``dim``) are forwarded unchanged.

    Raises:
        TypeError: If ``seq`` is not a list or tuple.
    """
    if not isinstance(seq, (list, tuple)):
        raise TypeError(
            "stack(): argument 'tensors' (position 1) must be tuple of Tensors, "
            "not Tensor"
        )
    # Guard the ``seq[0]`` lookup. An empty sequence used to die with an
    # IndexError here; now it reaches torch.stack, which reports it itself.
    if seq and isinstance(seq[0], ComplexTensor):
        return FC.stack(seq, *args, **kwargs)
    return torch.stack(seq, *args, **kwargs)