Spaces:
Running
Running
"""Real spherical harmonics in Cartesian form for PyTorch. | |
This is an autogenerated file. See | |
https://github.com/cheind/torch-spherical-harmonics | |
for more information. | |
""" | |
import torch | |
def rsh_cart_0(xyz: torch.Tensor):
    """Computes all real spherical harmonics up to degree 0.

    This is an autogenerated method. See
    https://github.com/cheind/torch-spherical-harmonics
    for more information.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere. Only its
            leading (batch) dimensions, dtype and device are used, since
            Y00 does not depend on direction.

    Returns:
        rsh: (N,...,1) real spherical harmonics
            projections of input, with the dtype and device of `xyz`.
            Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree`
            and `-n <= m <= n`.
    """
    # Y00 = 1 / (2 * sqrt(pi)); create it on xyz's dtype/device and
    # broadcast it over every leading dimension.
    batch_shape = xyz.shape[:-1]
    y00 = xyz.new_tensor(0.282094791773878).expand(batch_shape)
    return torch.stack((y00,), dim=-1)
def rsh_cart_1(xyz: torch.Tensor):
    """Computes all real spherical harmonics up to degree 1.

    This is an autogenerated method. See
    https://github.com/cheind/torch-spherical-harmonics
    for more information.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere. Only
            ``xyz[..., 0:3]`` is read.

    Returns:
        rsh: (N,...,4) real spherical harmonics
            projections of input, with the dtype and device of `xyz`.
            Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree`
            and `-n <= m <= n`.
    """
    x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2]
    c1 = 0.48860251190292  # shared magnitude of the three degree-1 terms

    harmonics = [xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1])]
    # Degree 1, ordered m = -1, 0, +1.
    harmonics += [-c1 * y, c1 * z, -c1 * x]
    return torch.stack(harmonics, dim=-1)
def rsh_cart_2(xyz: torch.Tensor):
    """Computes all real spherical harmonics up to degree 2.

    This is an autogenerated method. See
    https://github.com/cheind/torch-spherical-harmonics
    for more information.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere. Only
            ``xyz[..., 0:3]`` is read. The m=0 degree-2 term carries a
            constant offset that is only correct for unit-length input.

    Returns:
        rsh: (N,...,9) real spherical harmonics
            projections of input, with the dtype and device of `xyz`.
            Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree`
            and `-n <= m <= n`.
    """
    x = xyz[..., 0]
    y = xyz[..., 1]
    z = xyz[..., 2]
    x2, y2, z2 = x**2, y**2, z**2
    xy, xz, yz = x * y, x * z, y * z

    c1 = 0.48860251190292
    c2 = 1.09254843059208
    c2h = 0.54627421529604

    deg0 = [xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1])]
    deg1 = [-c1 * y, c1 * z, -c1 * x]
    deg2 = [
        c2 * xy,
        -c2 * yz,
        0.94617469575756 * z2 - 0.31539156525252,
        -c2 * xz,
        c2h * x2 - c2h * y2,
    ]
    return torch.stack(deg0 + deg1 + deg2, dim=-1)
def rsh_cart_3(xyz: torch.Tensor):
    """Computes all real spherical harmonics up to degree 3.

    This is an autogenerated method. See
    https://github.com/cheind/torch-spherical-harmonics
    for more information.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere. Only
            ``xyz[..., 0:3]`` is read. Several terms carry constant
            offsets that are only correct for unit-length input.

    Returns:
        rsh: (N,...,16) real spherical harmonics
            projections of input, with the dtype and device of `xyz`.
            Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree`
            and `-n <= m <= n`.
    """
    x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2]
    x2, y2, z2 = x**2, y**2, z**2
    xy, xz, yz = x * y, x * z, y * z

    # Coefficients shared by several terms.
    c1 = 0.48860251190292
    c2 = 1.09254843059208
    c2h = 0.54627421529604
    c3a = 0.590043589926644
    c3b = 0.304697199642977
    # Factor common to the (3, -1) and (3, +1) terms.
    q = 1.5 - 7.5 * z2

    terms = [
        # degree 0
        xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
        # degree 1 (m = -1 .. 1)
        -c1 * y,
        c1 * z,
        -c1 * x,
        # degree 2 (m = -2 .. 2)
        c2 * xy,
        -c2 * yz,
        0.94617469575756 * z2 - 0.31539156525252,
        -c2 * xz,
        c2h * x2 - c2h * y2,
        # degree 3 (m = -3 .. 3)
        -c3a * y * (3.0 * x2 - y2),
        2.89061144264055 * xy * z,
        c3b * y * q,
        1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
        c3b * x * q,
        1.44530572132028 * z * (x2 - y2),
        -c3a * x * (x2 - 3.0 * y2),
    ]
    return torch.stack(terms, dim=-1)
def rsh_cart_4(xyz: torch.Tensor):
    """Computes all real spherical harmonics up to degree 4.

    This is an autogenerated method. See
    https://github.com/cheind/torch-spherical-harmonics
    for more information.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere. Only
            ``xyz[..., 0:3]`` is read. The polynomials assume unit-length
            input (e.g. the degree-2, m=0 term has a constant offset that
            is only correct when x^2 + y^2 + z^2 = 1), so normalize first.

    Returns:
        rsh: (N,...,25) real spherical harmonics
            projections of input, with the dtype and device of `xyz`.
            Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree`
            and `-n <= m <= n`.
    """
    x = xyz[..., 0]
    y = xyz[..., 1]
    z = xyz[..., 2]
    x2 = x**2
    y2 = y**2
    z2 = z**2
    xy = x * y
    xz = x * z
    yz = y * z
    x4 = x2**2
    y4 = y2**2
    return torch.stack(
        [
            # degree 0 (index 0)
            xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
            # degree 1 (indices 1-3)
            -0.48860251190292 * y,
            0.48860251190292 * z,
            -0.48860251190292 * x,
            # degree 2 (indices 4-8)
            1.09254843059208 * xy,
            -1.09254843059208 * yz,
            0.94617469575756 * z2 - 0.31539156525252,
            -1.09254843059208 * xz,
            0.54627421529604 * x2 - 0.54627421529604 * y2,
            # degree 3 (indices 9-15)
            -0.590043589926644 * y * (3.0 * x2 - y2),
            2.89061144264055 * xy * z,
            0.304697199642977 * y * (1.5 - 7.5 * z2),
            1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
            0.304697199642977 * x * (1.5 - 7.5 * z2),
            1.44530572132028 * z * (x2 - y2),
            -0.590043589926644 * x * (x2 - 3.0 * y2),
            # degree 4 (indices 16-24)
            2.5033429417967 * xy * (x2 - y2),
            -1.77013076977993 * yz * (3.0 * x2 - y2),
            0.126156626101008 * xy * (52.5 * z2 - 7.5),
            0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            1.48099765681286
            * z
            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
            - 0.952069922236839 * z2
            + 0.317356640745613,
            0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
            -1.77013076977993 * xz * (x2 - 3.0 * y2),
            -3.75501441269506 * x2 * y2
            + 0.625835735449176 * x4
            + 0.625835735449176 * y4,
        ],
        -1,
    )
def rsh_cart_5(xyz: torch.Tensor):
    """Computes all real spherical harmonics up to degree 5.

    This is an autogenerated method. See
    https://github.com/cheind/torch-spherical-harmonics
    for more information.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere. Only
            ``xyz[..., 0:3]`` is read. The polynomials assume unit-length
            input (e.g. the degree-2, m=0 term has a constant offset that
            is only correct when x^2 + y^2 + z^2 = 1), so normalize first.

    Returns:
        rsh: (N,...,36) real spherical harmonics
            projections of input, with the dtype and device of `xyz`.
            Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree`
            and `-n <= m <= n`.
    """
    x = xyz[..., 0]
    y = xyz[..., 1]
    z = xyz[..., 2]
    x2 = x**2
    y2 = y**2
    z2 = z**2
    xy = x * y
    xz = x * z
    yz = y * z
    x4 = x2**2
    y4 = y2**2
    return torch.stack(
        [
            # degree 0 (index 0)
            xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
            # degree 1 (indices 1-3)
            -0.48860251190292 * y,
            0.48860251190292 * z,
            -0.48860251190292 * x,
            # degree 2 (indices 4-8)
            1.09254843059208 * xy,
            -1.09254843059208 * yz,
            0.94617469575756 * z2 - 0.31539156525252,
            -1.09254843059208 * xz,
            0.54627421529604 * x2 - 0.54627421529604 * y2,
            # degree 3 (indices 9-15)
            -0.590043589926644 * y * (3.0 * x2 - y2),
            2.89061144264055 * xy * z,
            0.304697199642977 * y * (1.5 - 7.5 * z2),
            1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
            0.304697199642977 * x * (1.5 - 7.5 * z2),
            1.44530572132028 * z * (x2 - y2),
            -0.590043589926644 * x * (x2 - 3.0 * y2),
            # degree 4 (indices 16-24)
            2.5033429417967 * xy * (x2 - y2),
            -1.77013076977993 * yz * (3.0 * x2 - y2),
            0.126156626101008 * xy * (52.5 * z2 - 7.5),
            0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            1.48099765681286
            * z
            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
            - 0.952069922236839 * z2
            + 0.317356640745613,
            0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
            -1.77013076977993 * xz * (x2 - 3.0 * y2),
            -3.75501441269506 * x2 * y2
            + 0.625835735449176 * x4
            + 0.625835735449176 * y4,
            # degree 5 (indices 25-35)
            -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            8.30264925952416 * xy * z * (x2 - y2),
            0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
            0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
            0.241571547304372
            * y
            * (
                2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 9.375 * z2
                - 1.875
            ),
            -1.24747010616985 * z * (1.5 * z2 - 0.5)
            + 1.6840846433293
            * z
            * (
                1.75
                * z
                * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                - 1.125 * z2
                + 0.375
            )
            + 0.498988042467941 * z,
            0.241571547304372
            * x
            * (
                2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 9.375 * z2
                - 1.875
            ),
            0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
            0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
            2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
            -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
        ],
        -1,
    )
def rsh_cart_6(xyz: torch.Tensor):
    """Computes all real spherical harmonics up to degree 6.

    This is an autogenerated method. See
    https://github.com/cheind/torch-spherical-harmonics
    for more information.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere. Only
            ``xyz[..., 0:3]`` is read. The polynomials assume unit-length
            input (e.g. the degree-2, m=0 term has a constant offset that
            is only correct when x^2 + y^2 + z^2 = 1), so normalize first.

    Returns:
        rsh: (N,...,49) real spherical harmonics
            projections of input, with the dtype and device of `xyz`.
            Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree`
            and `-n <= m <= n`.
    """
    x = xyz[..., 0]
    y = xyz[..., 1]
    z = xyz[..., 2]
    x2 = x**2
    y2 = y**2
    z2 = z**2
    xy = x * y
    xz = x * z
    yz = y * z
    x4 = x2**2
    y4 = y2**2
    return torch.stack(
        [
            # degree 0 (index 0)
            xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
            # degree 1 (indices 1-3)
            -0.48860251190292 * y,
            0.48860251190292 * z,
            -0.48860251190292 * x,
            # degree 2 (indices 4-8)
            1.09254843059208 * xy,
            -1.09254843059208 * yz,
            0.94617469575756 * z2 - 0.31539156525252,
            -1.09254843059208 * xz,
            0.54627421529604 * x2 - 0.54627421529604 * y2,
            # degree 3 (indices 9-15)
            -0.590043589926644 * y * (3.0 * x2 - y2),
            2.89061144264055 * xy * z,
            0.304697199642977 * y * (1.5 - 7.5 * z2),
            1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
            0.304697199642977 * x * (1.5 - 7.5 * z2),
            1.44530572132028 * z * (x2 - y2),
            -0.590043589926644 * x * (x2 - 3.0 * y2),
            # degree 4 (indices 16-24)
            2.5033429417967 * xy * (x2 - y2),
            -1.77013076977993 * yz * (3.0 * x2 - y2),
            0.126156626101008 * xy * (52.5 * z2 - 7.5),
            0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            1.48099765681286
            * z
            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
            - 0.952069922236839 * z2
            + 0.317356640745613,
            0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
            -1.77013076977993 * xz * (x2 - 3.0 * y2),
            -3.75501441269506 * x2 * y2
            + 0.625835735449176 * x4
            + 0.625835735449176 * y4,
            # degree 5 (indices 25-35)
            -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            8.30264925952416 * xy * z * (x2 - y2),
            0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
            0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
            0.241571547304372
            * y
            * (
                2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 9.375 * z2
                - 1.875
            ),
            -1.24747010616985 * z * (1.5 * z2 - 0.5)
            + 1.6840846433293
            * z
            * (
                1.75
                * z
                * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                - 1.125 * z2
                + 0.375
            )
            + 0.498988042467941 * z,
            0.241571547304372
            * x
            * (
                2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 9.375 * z2
                - 1.875
            ),
            0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
            0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
            2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
            -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            # degree 6 (indices 36-48)
            4.09910463115149 * x**4 * xy
            - 13.6636821038383 * xy**3
            + 4.09910463115149 * xy * y**4,
            -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
            0.00584892228263444
            * y
            * (3.0 * x2 - y2)
            * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
            0.0701870673916132
            * xy
            * (
                2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                - 91.875 * z2
                + 13.125
            ),
            0.221950995245231
            * y
            * (
                -2.8 * z * (1.5 - 7.5 * z2)
                + 2.2
                * z
                * (
                    2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                    + 9.375 * z2
                    - 1.875
                )
                - 4.8 * z
            ),
            -1.48328138624466
            * z
            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
            + 1.86469659985043
            * z
            * (
                -1.33333333333333 * z * (1.5 * z2 - 0.5)
                + 1.8
                * z
                * (
                    1.75
                    * z
                    * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                    - 1.125 * z2
                    + 0.375
                )
                + 0.533333333333333 * z
            )
            + 0.953538034014426 * z2
            - 0.317846011338142,
            0.221950995245231
            * x
            * (
                -2.8 * z * (1.5 - 7.5 * z2)
                + 2.2
                * z
                * (
                    2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                    + 9.375 * z2
                    - 1.875
                )
                - 4.8 * z
            ),
            0.0350935336958066
            * (x2 - y2)
            * (
                2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                - 91.875 * z2
                + 13.125
            ),
            0.00584892228263444
            * x
            * (x2 - 3.0 * y2)
            * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
            0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
            -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            0.683184105191914 * x2**3
            + 10.2477615778787 * x2 * y4
            - 10.2477615778787 * x4 * y2
            - 0.683184105191914 * y2**3,
        ],
        -1,
    )
def rsh_cart_7(xyz: torch.Tensor):
    """Computes all real spherical harmonics up to degree 7.

    This is an autogenerated method. See
    https://github.com/cheind/torch-spherical-harmonics
    for more information.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere. Only
            ``xyz[..., 0:3]`` is read. The polynomials assume unit-length
            input (e.g. the degree-2, m=0 term has a constant offset that
            is only correct when x^2 + y^2 + z^2 = 1), so normalize first.

    Returns:
        rsh: (N,...,64) real spherical harmonics
            projections of input, with the dtype and device of `xyz`.
            Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree`
            and `-n <= m <= n`.
    """
    x = xyz[..., 0]
    y = xyz[..., 1]
    z = xyz[..., 2]
    x2 = x**2
    y2 = y**2
    z2 = z**2
    xy = x * y
    xz = x * z
    yz = y * z
    x4 = x2**2
    y4 = y2**2
    return torch.stack(
        [
            # degree 0 (index 0)
            xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
            # degree 1 (indices 1-3)
            -0.48860251190292 * y,
            0.48860251190292 * z,
            -0.48860251190292 * x,
            # degree 2 (indices 4-8)
            1.09254843059208 * xy,
            -1.09254843059208 * yz,
            0.94617469575756 * z2 - 0.31539156525252,
            -1.09254843059208 * xz,
            0.54627421529604 * x2 - 0.54627421529604 * y2,
            # degree 3 (indices 9-15)
            -0.590043589926644 * y * (3.0 * x2 - y2),
            2.89061144264055 * xy * z,
            0.304697199642977 * y * (1.5 - 7.5 * z2),
            1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
            0.304697199642977 * x * (1.5 - 7.5 * z2),
            1.44530572132028 * z * (x2 - y2),
            -0.590043589926644 * x * (x2 - 3.0 * y2),
            # degree 4 (indices 16-24)
            2.5033429417967 * xy * (x2 - y2),
            -1.77013076977993 * yz * (3.0 * x2 - y2),
            0.126156626101008 * xy * (52.5 * z2 - 7.5),
            0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            1.48099765681286
            * z
            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
            - 0.952069922236839 * z2
            + 0.317356640745613,
            0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
            -1.77013076977993 * xz * (x2 - 3.0 * y2),
            -3.75501441269506 * x2 * y2
            + 0.625835735449176 * x4
            + 0.625835735449176 * y4,
            # degree 5 (indices 25-35)
            -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            8.30264925952416 * xy * z * (x2 - y2),
            0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
            0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
            0.241571547304372
            * y
            * (
                2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 9.375 * z2
                - 1.875
            ),
            -1.24747010616985 * z * (1.5 * z2 - 0.5)
            + 1.6840846433293
            * z
            * (
                1.75
                * z
                * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                - 1.125 * z2
                + 0.375
            )
            + 0.498988042467941 * z,
            0.241571547304372
            * x
            * (
                2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 9.375 * z2
                - 1.875
            ),
            0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
            0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
            2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
            -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            # degree 6 (indices 36-48)
            4.09910463115149 * x**4 * xy
            - 13.6636821038383 * xy**3
            + 4.09910463115149 * xy * y**4,
            -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
            0.00584892228263444
            * y
            * (3.0 * x2 - y2)
            * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
            0.0701870673916132
            * xy
            * (
                2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                - 91.875 * z2
                + 13.125
            ),
            0.221950995245231
            * y
            * (
                -2.8 * z * (1.5 - 7.5 * z2)
                + 2.2
                * z
                * (
                    2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                    + 9.375 * z2
                    - 1.875
                )
                - 4.8 * z
            ),
            -1.48328138624466
            * z
            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
            + 1.86469659985043
            * z
            * (
                -1.33333333333333 * z * (1.5 * z2 - 0.5)
                + 1.8
                * z
                * (
                    1.75
                    * z
                    * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                    - 1.125 * z2
                    + 0.375
                )
                + 0.533333333333333 * z
            )
            + 0.953538034014426 * z2
            - 0.317846011338142,
            0.221950995245231
            * x
            * (
                -2.8 * z * (1.5 - 7.5 * z2)
                + 2.2
                * z
                * (
                    2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                    + 9.375 * z2
                    - 1.875
                )
                - 4.8 * z
            ),
            0.0350935336958066
            * (x2 - y2)
            * (
                2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                - 91.875 * z2
                + 13.125
            ),
            0.00584892228263444
            * x
            * (x2 - 3.0 * y2)
            * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
            0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
            -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            0.683184105191914 * x2**3
            + 10.2477615778787 * x2 * y4
            - 10.2477615778787 * x4 * y2
            - 0.683184105191914 * y2**3,
            # degree 7 (indices 49-63)
            -0.707162732524596
            * y
            * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
            2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
            9.98394571852353e-5
            * y
            * (5197.5 - 67567.5 * z2)
            * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            0.00239614697244565
            * xy
            * (x2 - y2)
            * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z),
            0.00397356022507413
            * y
            * (3.0 * x2 - y2)
            * (
                3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
                + 1063.125 * z2
                - 118.125
            ),
            0.0561946276120613
            * xy
            * (
                -4.8 * z * (52.5 * z2 - 7.5)
                + 2.6
                * z
                * (
                    2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                    - 91.875 * z2
                    + 13.125
                )
                + 48.0 * z
            ),
            0.206472245902897
            * y
            * (
                -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 2.16666666666667
                * z
                * (
                    -2.8 * z * (1.5 - 7.5 * z2)
                    + 2.2
                    * z
                    * (
                        2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                        + 9.375 * z2
                        - 1.875
                    )
                    - 4.8 * z
                )
                - 10.9375 * z2
                + 2.1875
            ),
            1.24862677781952 * z * (1.5 * z2 - 0.5)
            - 1.68564615005635
            * z
            * (
                1.75
                * z
                * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                - 1.125 * z2
                + 0.375
            )
            + 2.02901851395672
            * z
            * (
                -1.45833333333333
                * z
                * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                + 1.83333333333333
                * z
                * (
                    -1.33333333333333 * z * (1.5 * z2 - 0.5)
                    + 1.8
                    * z
                    * (
                        1.75
                        * z
                        * (
                            1.66666666666667 * z * (1.5 * z2 - 0.5)
                            - 0.666666666666667 * z
                        )
                        - 1.125 * z2
                        + 0.375
                    )
                    + 0.533333333333333 * z
                )
                + 0.9375 * z2
                - 0.3125
            )
            - 0.499450711127808 * z,
            0.206472245902897
            * x
            * (
                -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 2.16666666666667
                * z
                * (
                    -2.8 * z * (1.5 - 7.5 * z2)
                    + 2.2
                    * z
                    * (
                        2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                        + 9.375 * z2
                        - 1.875
                    )
                    - 4.8 * z
                )
                - 10.9375 * z2
                + 2.1875
            ),
            0.0280973138060306
            * (x2 - y2)
            * (
                -4.8 * z * (52.5 * z2 - 7.5)
                + 2.6
                * z
                * (
                    2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                    - 91.875 * z2
                    + 13.125
                )
                + 48.0 * z
            ),
            0.00397356022507413
            * x
            * (x2 - 3.0 * y2)
            * (
                3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
                + 1063.125 * z2
                - 118.125
            ),
            0.000599036743111412
            * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
            * (-6.0 * x2 * y2 + x4 + y4),
            9.98394571852353e-5
            * x
            * (5197.5 - 67567.5 * z2)
            * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
            -0.707162732524596
            * x
            * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
        ],
        -1,
    )
def rsh_cart_8(xyz: torch.Tensor):
    """Computes all real spherical harmonics up to degree 8.

    This is an autogenerated method. See
    https://github.com/cheind/torch-spherical-harmonics
    for more information.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere. Only
            ``xyz[..., 0:3]`` is read. The polynomials assume unit-length
            input (e.g. the degree-2, m=0 term has a constant offset that
            is only correct when x^2 + y^2 + z^2 = 1), so normalize first.

    Returns:
        rsh: (N,...,81) real spherical harmonics
            projections of input, with the dtype and device of `xyz`.
            Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree`
            and `-n <= m <= n`.
    """
    x = xyz[..., 0]
    y = xyz[..., 1]
    z = xyz[..., 2]
    x2 = x**2
    y2 = y**2
    z2 = z**2
    xy = x * y
    xz = x * z
    yz = y * z
    x4 = x2**2
    y4 = y2**2
    return torch.stack(
        [
            # degree 0 (index 0). Built from xyz so it shares xyz's dtype and
            # device and can broadcast to any batch shape, including the
            # empty shape of a single (3,) point.
            xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
            # degree 1 (indices 1-3)
            -0.48860251190292 * y,
            0.48860251190292 * z,
            -0.48860251190292 * x,
            # degree 2 (indices 4-8)
            1.09254843059208 * xy,
            -1.09254843059208 * yz,
            0.94617469575756 * z2 - 0.31539156525252,
            -1.09254843059208 * xz,
            0.54627421529604 * x2 - 0.54627421529604 * y2,
            # degree 3 (indices 9-15)
            -0.590043589926644 * y * (3.0 * x2 - y2),
            2.89061144264055 * xy * z,
            0.304697199642977 * y * (1.5 - 7.5 * z2),
            1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
            0.304697199642977 * x * (1.5 - 7.5 * z2),
            1.44530572132028 * z * (x2 - y2),
            -0.590043589926644 * x * (x2 - 3.0 * y2),
            # degree 4 (indices 16-24)
            2.5033429417967 * xy * (x2 - y2),
            -1.77013076977993 * yz * (3.0 * x2 - y2),
            0.126156626101008 * xy * (52.5 * z2 - 7.5),
            0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            1.48099765681286
            * z
            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
            - 0.952069922236839 * z2
            + 0.317356640745613,
            0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
            -1.77013076977993 * xz * (x2 - 3.0 * y2),
            -3.75501441269506 * x2 * y2
            + 0.625835735449176 * x4
            + 0.625835735449176 * y4,
            # degree 5 (indices 25-35)
            -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            8.30264925952416 * xy * z * (x2 - y2),
            0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
            0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
            0.241571547304372
            * y
            * (
                2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 9.375 * z2
                - 1.875
            ),
            -1.24747010616985 * z * (1.5 * z2 - 0.5)
            + 1.6840846433293
            * z
            * (
                1.75
                * z
                * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                - 1.125 * z2
                + 0.375
            )
            + 0.498988042467941 * z,
            0.241571547304372
            * x
            * (
                2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 9.375 * z2
                - 1.875
            ),
            0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
            0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
            2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
            -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            # degree 6 (indices 36-48)
            4.09910463115149 * x**4 * xy
            - 13.6636821038383 * xy**3
            + 4.09910463115149 * xy * y**4,
            -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
            0.00584892228263444
            * y
            * (3.0 * x2 - y2)
            * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
            0.0701870673916132
            * xy
            * (
                2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                - 91.875 * z2
                + 13.125
            ),
            0.221950995245231
            * y
            * (
                -2.8 * z * (1.5 - 7.5 * z2)
                + 2.2
                * z
                * (
                    2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                    + 9.375 * z2
                    - 1.875
                )
                - 4.8 * z
            ),
            -1.48328138624466
            * z
            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
            + 1.86469659985043
            * z
            * (
                -1.33333333333333 * z * (1.5 * z2 - 0.5)
                + 1.8
                * z
                * (
                    1.75
                    * z
                    * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                    - 1.125 * z2
                    + 0.375
                )
                + 0.533333333333333 * z
            )
            + 0.953538034014426 * z2
            - 0.317846011338142,
            0.221950995245231
            * x
            * (
                -2.8 * z * (1.5 - 7.5 * z2)
                + 2.2
                * z
                * (
                    2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                    + 9.375 * z2
                    - 1.875
                )
                - 4.8 * z
            ),
            0.0350935336958066
            * (x2 - y2)
            * (
                2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                - 91.875 * z2
                + 13.125
            ),
            0.00584892228263444
            * x
            * (x2 - 3.0 * y2)
            * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
            0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
            -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            0.683184105191914 * x2**3
            + 10.2477615778787 * x2 * y4
            - 10.2477615778787 * x4 * y2
            - 0.683184105191914 * y2**3,
            # degree 7 (indices 49-63)
            -0.707162732524596
            * y
            * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
            2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
            9.98394571852353e-5
            * y
            * (5197.5 - 67567.5 * z2)
            * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            0.00239614697244565
            * xy
            * (x2 - y2)
            * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z),
            0.00397356022507413
            * y
            * (3.0 * x2 - y2)
            * (
                3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
                + 1063.125 * z2
                - 118.125
            ),
            0.0561946276120613
            * xy
            * (
                -4.8 * z * (52.5 * z2 - 7.5)
                + 2.6
                * z
                * (
                    2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                    - 91.875 * z2
                    + 13.125
                )
                + 48.0 * z
            ),
            0.206472245902897
            * y
            * (
                -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 2.16666666666667
                * z
                * (
                    -2.8 * z * (1.5 - 7.5 * z2)
                    + 2.2
                    * z
                    * (
                        2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                        + 9.375 * z2
                        - 1.875
                    )
                    - 4.8 * z
                )
                - 10.9375 * z2
                + 2.1875
            ),
            1.24862677781952 * z * (1.5 * z2 - 0.5)
            - 1.68564615005635
            * z
            * (
                1.75
                * z
                * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                - 1.125 * z2
                + 0.375
            )
            + 2.02901851395672
            * z
            * (
                -1.45833333333333
                * z
                * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                + 1.83333333333333
                * z
                * (
                    -1.33333333333333 * z * (1.5 * z2 - 0.5)
                    + 1.8
                    * z
                    * (
                        1.75
                        * z
                        * (
                            1.66666666666667 * z * (1.5 * z2 - 0.5)
                            - 0.666666666666667 * z
                        )
                        - 1.125 * z2
                        + 0.375
                    )
                    + 0.533333333333333 * z
                )
                + 0.9375 * z2
                - 0.3125
            )
            - 0.499450711127808 * z,
            0.206472245902897
            * x
            * (
                -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 2.16666666666667
                * z
                * (
                    -2.8 * z * (1.5 - 7.5 * z2)
                    + 2.2
                    * z
                    * (
                        2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                        + 9.375 * z2
                        - 1.875
                    )
                    - 4.8 * z
                )
                - 10.9375 * z2
                + 2.1875
            ),
            0.0280973138060306
            * (x2 - y2)
            * (
                -4.8 * z * (52.5 * z2 - 7.5)
                + 2.6
                * z
                * (
                    2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                    - 91.875 * z2
                    + 13.125
                )
                + 48.0 * z
            ),
            0.00397356022507413
            * x
            * (x2 - 3.0 * y2)
            * (
                3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
                + 1063.125 * z2
                - 118.125
            ),
            0.000599036743111412
            * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
            * (-6.0 * x2 * y2 + x4 + y4),
            9.98394571852353e-5
            * x
            * (5197.5 - 67567.5 * z2)
            * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
            -0.707162732524596
            * x
            * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
            # degree 8 (indices 64-80)
            5.83141328139864 * xy * (x2**3 + 7.0 * x2 * y4 - 7.0 * x4 * y2 - y2**3),
            -2.91570664069932
            * yz
            * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
            7.87853281621404e-6
            * (1013512.5 * z2 - 67567.5)
            * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
            5.10587282657803e-5
            * y
            * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z)
            * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            0.00147275890257803
            * xy
            * (x2 - y2)
            * (
                3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
                - 14293.125 * z2
                + 1299.375
            ),
            0.0028519853513317
            * y
            * (3.0 * x2 - y2)
            * (
                -7.33333333333333 * z * (52.5 - 472.5 * z2)
                + 3.0
                * z
                * (
                    3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
                    + 1063.125 * z2
                    - 118.125
                )
                - 560.0 * z
            ),
            0.0463392770473559
            * xy
            * (
                -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                + 2.5
                * z
                * (
                    -4.8 * z * (52.5 * z2 - 7.5)
                    + 2.6
                    * z
                    * (
                        2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                        - 91.875 * z2
                        + 13.125
                    )
                    + 48.0 * z
                )
                + 137.8125 * z2
                - 19.6875
            ),
            0.193851103820053
            * y
            * (
                3.2 * z * (1.5 - 7.5 * z2)
                - 2.51428571428571
                * z
                * (
                    2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                    + 9.375 * z2
                    - 1.875
                )
                + 2.14285714285714
                * z
                * (
                    -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                    + 2.16666666666667
                    * z
                    * (
                        -2.8 * z * (1.5 - 7.5 * z2)
                        + 2.2
                        * z
                        * (
                            2.25
                            * z
                            * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                            + 9.375 * z2
                            - 1.875
                        )
                        - 4.8 * z
                    )
                    - 10.9375 * z2
                    + 2.1875
                )
                + 5.48571428571429 * z
            ),
            1.48417251362228
            * z
            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
            - 1.86581687426801
            * z
            * (
                -1.33333333333333 * z * (1.5 * z2 - 0.5)
                + 1.8
                * z
                * (
                    1.75
                    * z
                    * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                    - 1.125 * z2
                    + 0.375
                )
                + 0.533333333333333 * z
            )
            + 2.1808249179756
            * z
            * (
                1.14285714285714 * z * (1.5 * z2 - 0.5)
                - 1.54285714285714
                * z
                * (
                    1.75
                    * z
                    * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                    - 1.125 * z2
                    + 0.375
                )
                + 1.85714285714286
                * z
                * (
                    -1.45833333333333
                    * z
                    * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                    + 1.83333333333333
                    * z
                    * (
                        -1.33333333333333 * z * (1.5 * z2 - 0.5)
                        + 1.8
                        * z
                        * (
                            1.75
                            * z
                            * (
                                1.66666666666667 * z * (1.5 * z2 - 0.5)
                                - 0.666666666666667 * z
                            )
                            - 1.125 * z2
                            + 0.375
                        )
                        + 0.533333333333333 * z
                    )
                    + 0.9375 * z2
                    - 0.3125
                )
                - 0.457142857142857 * z
            )
            - 0.954110901614325 * z2
            + 0.318036967204775,
            0.193851103820053
            * x
            * (
                3.2 * z * (1.5 - 7.5 * z2)
                - 2.51428571428571
                * z
                * (
                    2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                    + 9.375 * z2
                    - 1.875
                )
                + 2.14285714285714
                * z
                * (
                    -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                    + 2.16666666666667
                    * z
                    * (
                        -2.8 * z * (1.5 - 7.5 * z2)
                        + 2.2
                        * z
                        * (
                            2.25
                            * z
                            * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                            + 9.375 * z2
                            - 1.875
                        )
                        - 4.8 * z
                    )
                    - 10.9375 * z2
                    + 2.1875
                )
                + 5.48571428571429 * z
            ),
            0.0231696385236779
            * (x2 - y2)
            * (
                -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                + 2.5
                * z
                * (
                    -4.8 * z * (52.5 * z2 - 7.5)
                    + 2.6
                    * z
                    * (
                        2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                        - 91.875 * z2
                        + 13.125
                    )
                    + 48.0 * z
                )
                + 137.8125 * z2
                - 19.6875
            ),
            0.0028519853513317
            * x
            * (x2 - 3.0 * y2)
            * (
                -7.33333333333333 * z * (52.5 - 472.5 * z2)
                + 3.0
                * z
                * (
                    3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
                    + 1063.125 * z2
                    - 118.125
                )
                - 560.0 * z
            ),
            0.000368189725644507
            * (-6.0 * x2 * y2 + x4 + y4)
            * (
                3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
                - 14293.125 * z2
                + 1299.375
            ),
            5.10587282657803e-5
            * x
            * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z)
            * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            7.87853281621404e-6
            * (1013512.5 * z2 - 67567.5)
            * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
            -2.91570664069932
            * xz
            * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
            -20.4099464848952 * x2**3 * y2
            - 20.4099464848952 * x2 * y2**3
            + 0.72892666017483 * x4**2
            + 51.0248662122381 * x4 * y4
            + 0.72892666017483 * y4**2,
        ],
        -1,
    )
# Public API: one Cartesian real-spherical-harmonics evaluator per maximum
# degree, rsh_cart_<d> returning (d + 1) ** 2 coefficients.
__all__ = [
    "rsh_cart_0", "rsh_cart_1", "rsh_cart_2",
    "rsh_cart_3", "rsh_cart_4", "rsh_cart_5",
    "rsh_cart_6", "rsh_cart_7", "rsh_cart_8",
]
from typing import Optional | |
import torch | |
class SphHarm(torch.nn.Module):
    """Spherical harmonics for a fixed set of (order, degree) pairs.

    At construction time the module enumerates orders ``-(m - 1) .. m - 1``
    and degrees ``0 .. n - 1``, keeps the pairs with ``order <= degree``, and
    precomputes the associated-Legendre recurrence coefficients as buffers.
    :meth:`forward` then evaluates every kept pair at every input point.

    Args:
        m: Exclusive bound on ``|order|``.
        n: Number of degrees (degrees ``0 .. n - 1`` are used).
        dtype: Floating dtype used for the internal computation.
        is_normalized: If True, use the normalized recurrences so that the
            harmonics form an orthonormal basis of L^2(S^2); the default
            (False) keeps the unnormalized associated Legendre functions.
    """

    def __init__(
        self, m, n, dtype=torch.float32, is_normalized: bool = False
    ) -> None:
        super().__init__()
        self.dtype = dtype
        # Selects which set of recurrence coefficients `_init_legendre` and
        # `_gen_recurrence_mask` build.
        self.is_normalized = is_normalized
        orders = torch.tensor(list(range(-m + 1, m)))
        degrees = torch.tensor(list(range(n)))
        pairs = torch.cartesian_prod(orders, degrees).T
        # NOTE(review): this keeps `order <= degree` rather than
        # `|order| <= degree`, so negative orders with |order| > degree are
        # kept. Their Legendre factor is 0, so those output channels are
        # identically zero. Kept as-is to preserve the output width.
        pairs = pairs[:, pairs[0] <= pairs[1]]
        m, n = pairs.unbind(0)
        self.register_buffer("m", tensor=m)
        self.register_buffer("n", tensor=n)
        self.register_buffer("l_max", tensor=torch.max(self.n))
        f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d = self._init_legendre()
        self.register_buffer("f_a", tensor=f_a)
        self.register_buffer("f_b", tensor=f_b)
        self.register_buffer("d0_mask_3d", tensor=d0_mask_3d)
        self.register_buffer("d1_mask_3d", tensor=d1_mask_3d)
        self.register_buffer("initial_value", tensor=initial_value)

    def device(self) -> torch.device:
        """Returns the device of the module's buffers (all share one device)."""
        return next(self.buffers()).device

    def forward(self, points: torch.Tensor) -> torch.Tensor:
        """Evaluates every configured harmonic at every input point.

        Y_l^m = (-1)^m c_l^m P_l^m(cos(phi)) exp(i m theta)

        Args:
            points: Tensor of shape ``(B, N, 2)``. The last axis holds
                ``(theta, phi)``: ``theta`` multiplies the order in the complex
                exponential and ``cos(phi)`` is fed to the Legendre functions.

        Returns:
            Tensor of shape ``(B, N, K)`` with ``K = self.m.numel()``, in the
            dtype of ``points``. For a real input dtype only the real part of
            the complex harmonics is returned.

        Raises:
            ValueError: If ``points`` is not 3-D with a last dimension of 2.
        """
        B, N, D = points.shape
        if D != 2:
            raise ValueError(
                f"expected points with last dimension 2 (theta, phi), got {D}"
            )
        out_dtype = points.dtype
        theta, phi = points.reshape(-1, D).to(self.dtype).unbind(-1)
        legendre = self._gen_associated_legendre(torch.cos(phi))
        abs_m = self.m.abs()
        # Advanced indexing on the (order, degree) axes gives shape (K, P):
        # row k holds P_{n_k}^{|m_k|} at every point, matching `angle` below.
        legendre_vals = legendre[abs_m, self.n]
        angle = torch.outer(abs_m.to(theta.dtype), theta)
        harmonics = torch.complex(
            legendre_vals * torch.cos(angle),
            legendre_vals * torch.sin(angle),
        )
        # Negative order: (-1)^|m| times the conjugate of the |m| harmonic.
        m = self.m.unsqueeze(-1)
        harmonics = torch.where(
            m < 0, (-1.0) ** m.abs() * torch.conj(harmonics), harmonics
        )
        harmonics = harmonics.permute(1, 0).reshape(B, N, -1)
        if not out_dtype.is_complex:
            # Same values as an implicit complex->real cast, without the warning.
            harmonics = harmonics.real
        return harmonics.to(out_dtype)

    def _gen_recurrence_mask(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Generates the coefficient masks for the three-term recurrence.

        The recurrence fills the entries that are neither diagonal (l == m)
        nor first off-diagonal (l == m + 1). Coefficients are indexed
        ``[order m, degree l]``. ``d0`` is kept where ``l >= m + 1`` and ``d1``
        where ``l >= m + 2``; everything else is zero.

        Returns:
            ``(d0_mask_3d, d1_mask_3d)``, each of shape
            ``(l_max + 1, l_max + 1, l_max + 1)``. Slice ``[i]`` holds the
            coefficients only on the entries ``(j, j + i)``, i.e. the
            ``i``-th superdiagonal updated at recurrence step ``i``.
        """
        l_max = int(self.l_max)
        size = l_max + 1
        device = self.device()
        m_mat, l_mat = torch.meshgrid(
            torch.arange(size, device=device, dtype=self.dtype),
            torch.arange(size, device=device, dtype=self.dtype),
            indexing="ij",
        )
        if self.is_normalized:
            c0 = l_mat * l_mat
            c1 = m_mat * m_mat
            c2 = 2.0 * l_mat
            c3 = (l_mat - 1.0) * (l_mat - 1.0)
            d0 = torch.sqrt((4.0 * c0 - 1.0) / (c0 - c1))
            d1 = torch.sqrt(((c2 + 1.0) * (c3 - c1)) / ((c2 - 3.0) * (c0 - c1)))
        else:
            d0 = (2.0 * l_mat - 1.0) / (l_mat - m_mat)
            d1 = (l_mat + m_mat - 1.0) / (l_mat - m_mat)
        # Outside the valid regions the formulas divide by zero (l == m) or
        # take sqrt of negatives. Select with torch.where rather than
        # multiplying so those inf/nan values cannot leak into the masks.
        zeros = torch.zeros_like(d0)
        d0_mask = torch.where(l_mat - m_mat >= 1, d0, zeros)
        d1_mask = torch.where(l_mat - m_mat >= 2, d1, zeros)
        # 3D mask with 1 where k == i + j (plane i = i-th superdiagonal).
        idx = torch.arange(size, device=device)
        mask = (
            idx[:, None, None] + idx[None, :, None] - idx[None, None, :] == 0
        ).to(self.dtype)
        d0_mask_3d = d0_mask.unsqueeze(0) * mask
        d1_mask_3d = d1_mask.unsqueeze(0) * mask
        return (d0_mask_3d, d1_mask_3d)

    def _recursive(self, i: int, p_val: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Applies recurrence step ``i`` to the ALF table ``p_val``.

        Computes, on the ``i``-th superdiagonal only (other coefficients are
        zero), ``p[m, l] += d0 * x * p[m, l - 1] - d1 * p[m, l - 2]``, and
        returns a new tensor.
        """
        coeff_0 = self.d0_mask_3d[i].unsqueeze(-1)
        coeff_1 = self.d1_mask_3d[i].unsqueeze(-1)
        h = coeff_0 * torch.roll(p_val, shifts=1, dims=1) * x - coeff_1 * torch.roll(
            p_val, shifts=2, dims=1
        )
        return p_val + h

    def _init_legendre(self):
        """Precomputes the diagonal/off-diagonal factors and recurrence masks.

        Returns:
            ``(f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d)``, where
            ``f_a`` (length ``l_max``) scales the diagonal entries ``p(l, l)``,
            ``f_b`` (length ``l_max``) scales the entries ``p(l, l + 1)``,
            ``initial_value`` is ``p(0, 0)``, and the masks come from
            :meth:`_gen_recurrence_mask`.
        """
        l_max = int(self.l_max)
        device = self.device()
        a_idx = torch.arange(1, l_max + 1, dtype=self.dtype, device=device)
        b_idx = torch.arange(l_max, dtype=self.dtype, device=device)
        if self.is_normalized:
            # The initial value p(0,0).
            initial_value = torch.tensor(
                0.5 / (torch.pi**0.5), dtype=self.dtype, device=device
            )
            f_a = torch.cumprod(-1 * torch.sqrt(1.0 + 0.5 / a_idx), dim=0)
            f_b = torch.sqrt(2.0 * b_idx + 3.0)
        else:
            # The initial value p(0,0).
            initial_value = torch.tensor(1.0, dtype=self.dtype, device=device)
            f_a = torch.cumprod(1.0 - 2.0 * a_idx, dim=0)
            f_b = 2.0 * b_idx + 1.0
        d0_mask_3d, d1_mask_3d = self._gen_recurrence_mask()
        return f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d

    def _gen_associated_legendre(self, x: torch.Tensor) -> torch.Tensor:
        r"""Computes associated Legendre functions (ALFs) of the first kind.

        The spherical harmonic of degree `l` and order `m` can be written as
        `Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the
        normalization factor and θ and φ are the colatitude and longitude,
        respectively. When `self.is_normalized` is True, `N_l^m` is folded into
        the ALFs so that the spherical harmonics form an orthonormal basis of
        L^2(S^2); normalizing `P_l^m` also avoids overflow/underflow. Three
        recurrences are used: one for the diagonal entries `p(l, l)`, one for
        `p(l, l + 1)`, and a three-term recurrence for the rest.

        The maximum degree is `self.l_max`; both degrees and orders span
        `[0, 1, ..., l_max]`.

        Args:
            x: 1-D tensor of evaluation points, essentially `cos(θ)`, so it is
                expected to lie in `[-1, 1]` (e.g. Gauss-Legendre,
                Gauss-Chebyshev, or Driscoll & Healy quadrature points).

        Returns:
            Tensor of shape `(l_max + 1, l_max + 1, len(x))` holding the ALF
            values at `x`, indexed by order, degree, and evaluation point.
            Entries with order > degree are zero.
        """
        l_max = int(self.l_max)
        p = torch.zeros(
            (l_max + 1, l_max + 1, x.shape[0]), dtype=x.dtype, device=x.device
        )
        p[0, 0] = self.initial_value
        if l_max == 0:
            return p
        diag = torch.arange(l_max + 1, device=x.device)
        # Diagonal entries: p(l, l) = initial_value * f_a[l-1] * (1 - x^2)^(l/2).
        y = torch.cumprod(
            torch.broadcast_to(torch.sqrt(1.0 - x * x), (l_max, x.shape[0])), dim=0
        )
        p[diag[1:], diag[1:]] = self.initial_value * self.f_a[:, None] * y
        # First off-diagonal: p(l, l + 1) = f_b[l] * x * p(l, l).
        p[diag[:-1], diag[:-1] + 1] = (
            self.f_b[:, None] * x[None, :] * p[diag[:-1], diag[:-1]]
        )
        # Remaining entries, one superdiagonal per step.
        for i in range(2, l_max + 1):
            p = self._recursive(i, p, x)
        return p