import numpy as np
import copy
from logging import debug
[docs]
def sqrtbprod(cntrlv, chi, inverse=False, ensemble=False, **kwargs):
    """Multiplies Chi by B**0.5.

    For every controlled tracer, the tracer's slice of ``chi`` is pushed
    through the optional horizontal, temporal and vertical square-root
    correlation operators (applied in that order). The result is written
    into the tracer's slice of the output. Finally, the whole vector is
    scaled by ``cntrlv.std`` and shifted by ``cntrlv.xb``::

        x = xb + std * (B_vert . B_temp . B_horiz . chi)

    Each square-root operator is built from the tracer's eigen
    decomposition (``evectors`` E, ``sqrt_evalues`` s):

    - as ``E . diag(s)`` when ``cntrlv.reduced_chi`` is set (approximation
      in the reduced space);
    - as ``E . diag(s) . E^T`` otherwise.

    The horizontal operator is skipped when ``tracer.hresol == "global"``.

    Args:
        cntrlv: control vector. Its ``datavect.components``, ``dim``,
            ``std``, ``xb`` and ``reduced_chi`` attributes are read.
        chi: vector in control (chi) space. When ``ensemble`` is True,
            its last axis indexes the ensemble members.
        inverse: if True, the square-root eigenvalues and ``std`` are
            replaced by their reciprocals. ``xb`` is still added.
        ensemble: process a stack of ``chi`` columns at once.
        **kwargs: ignored.

    Returns:
        numpy array of shape ``(cntrlv.dim, nsample)`` when ``ensemble``
        is True, ``(cntrlv.dim,)`` otherwise.
    """
    debug("Start computing chi -> x = xb + B^0.5 . chi")

    def _sqrt_operator(corr):
        # Square-root correlation matrix from an eigen decomposition;
        # eigenvalues are inverted for the product by B**{-1/2}.
        sqrt_evalues = corr.sqrt_evalues
        evectors = corr.evectors
        if inverse:
            sqrt_evalues = sqrt_evalues[:] ** -1
        operator = evectors * sqrt_evalues[np.newaxis, :]
        if cntrlv.reduced_chi:
            return operator
        return operator.dot(evectors.T)

    # Work on 2D (dim, nsample) arrays in both modes
    if ensemble:
        nsample = chi.shape[-1]
        chi_2d = chi
    else:
        nsample = 1
        chi_2d = chi[:, np.newaxis]

    xout = np.zeros((cntrlv.dim, nsample))

    components = cntrlv.datavect.components
    for comp_name in components.attributes:
        component = getattr(components, comp_name)
        # Nothing to do for components without parameters
        if not hasattr(component, "parameters"):
            continue
        parameters = component.parameters
        for trcr_name in parameters.attributes:
            tracer = getattr(parameters, trcr_name)
            if not tracer.iscontrol:
                continue
            debug("Computing sqrtBprod for {}/{}".format(
                comp_name, trcr_name))
            x_start = tracer.xpointer
            x_stop = x_start + tracer.dim
            chi_start = tracer.chi_pointer
            chi_stop = chi_start + tracer.chi_dim
            ndates = tracer.ndates

            # Horizontal: (chi_tres, chi_vres, chi_hres) -> (..., hres)
            if hasattr(tracer, "hcorrelations") \
                    and tracer.hresol != "global":
                debug("Dealing with horizontal correlation")
                bspace = _sqrt_operator(tracer.hcorrelations)
                stacks = chi_2d[chi_start:chi_stop].reshape(
                    tracer.chi_tresoldim, tracer.chi_vresoldim,
                    tracer.chi_hresoldim, -1)
                # dot contracts the horizontal (second-to-last) axis
                values = np.transpose(
                    bspace.dot(stacks), axes=(1, 2, 0, 3)
                ).reshape(-1, nsample)
            else:
                values = chi_2d[chi_start:chi_stop]

            # Temporal: (chi_tres, chi_vres, hres) -> (ndates, chi_vres, hres)
            if hasattr(tracer, "tcorrelations"):
                debug("Dealing with temporal correlation")
                btemp = _sqrt_operator(tracer.tcorrelations)
                stacks = values.reshape(
                    tracer.chi_tresoldim, tracer.chi_vresoldim,
                    tracer.hresoldim, -1)
                # Move time to the contracted axis before the product
                values = btemp.dot(
                    stacks.transpose((1, 2, 0, 3))
                ).reshape(-1, nsample)

            # Vertical: (ndates, chi_vres, hres) -> (ndates, vres, hres)
            if hasattr(tracer, "vcorrelations"):
                debug("Dealing with vertical correlation")
                bvert = _sqrt_operator(tracer.vcorrelations)
                stacks = values.reshape(
                    ndates, tracer.chi_vresoldim, tracer.hresoldim, -1)
                values = np.transpose(
                    bvert.dot(stacks.transpose((0, 2, 1, 3))),
                    axes=(1, 0, 2, 3)
                ).reshape(-1, nsample)

            # Filling corresponding part in the control vector
            xout[x_start:x_stop] = values

    std = cntrlv.std[:] ** -1 if inverse else cntrlv.std[:]
    debug("Successfully computed chi -> x = xb + B^0.5 . chi")
    xout = xout * std[:, np.newaxis] + cntrlv.xb[:, np.newaxis]
    return xout if ensemble else xout[..., 0]
[docs]
def sqrtbprod_ad(cntrlv, dx, inverse=False, compute_sqrt=True, **kwargs):
    """Applies the transpose of the ``sqrtbprod`` linear operator to ``dx``.

    Maps an increment ``dx`` in state (x) space back to control (chi)
    space. For every controlled tracer:

    1. The tracer's slice of ``dx`` is scaled by ``std``.
    2. It is passed through the *transposed* vertical, temporal and
       horizontal square-root correlation operators. This is the reverse
       of the order used by ``sqrtbprod``.

    Each operator is ``E . diag(s)`` when ``cntrlv.reduced_chi`` is set and
    ``E . diag(s) . E^T`` otherwise. The horizontal operator is skipped
    when ``tracer.hresol == "global"``.

    The exponent applied to ``std`` and to the square-root eigenvalues
    ``s`` is:

    - ``-1`` if ``inverse`` (product by B^-0.5);
    - ``1`` if ``compute_sqrt`` (product by B^0.5);
    - ``2`` otherwise (product by B).

    Args:
        cntrlv: control vector. Its ``datavect.components``, ``chi_dim``,
            ``std`` and ``reduced_chi`` attributes are read.
        dx: 1D increment in state space.
        inverse: use reciprocal ``std`` and eigenvalues. Takes precedence
            over ``compute_sqrt``.
        compute_sqrt: if False (and not ``inverse``), square ``std`` and
            the eigenvalues instead.
        **kwargs: ignored.

    Returns:
        1D numpy array of length ``cntrlv.chi_dim``.
    """
    # The three cases are mutually exclusive, so exactly one
    # "Start computing" message is logged. The label only feeds the logs.
    if inverse:
        power, label = -1, "B^-0.5"
    elif compute_sqrt:
        power, label = 1, "B^0.5"
    else:
        power, label = 2, "B"
    debug("Start computing dx -> dchi = {} . dx".format(label))

    def _to_power(values):
        # Plain square root: leave the values untouched
        return values if power == 1 else values ** power

    def _transposed_sqrt(corr):
        # Transpose of E.diag(s) (reduced chi) or E.diag(s).E^T (full chi)
        evectors = corr.evectors
        operator = evectors * _to_power(corr.sqrt_evalues)[np.newaxis, :]
        if not cntrlv.reduced_chi:
            operator = operator.dot(evectors.T)
        return operator.T

    # Initializes output vector
    chiout = np.zeros(cntrlv.chi_dim)

    components = cntrlv.datavect.components
    for comp in components.attributes:
        component = getattr(components, comp)
        # Skip if component does not have parameters
        if not hasattr(component, "parameters"):
            continue
        for trcr in component.parameters.attributes:
            tracer = getattr(component.parameters, trcr)
            if not tracer.iscontrol:
                continue
            debug("Computing sqrtBprod_ad for {}/{}".format(comp, trcr))
            x_pointer = tracer.xpointer
            x_dim = tracer.dim
            chi_pointer = tracer.chi_pointer
            chi_dim = tracer.chi_dim
            ndates = tracer.ndates

            # x * std (raised to the requested power)
            xstd = (
                dx[x_pointer: x_pointer + x_dim]
                * _to_power(cntrlv.std[x_pointer: x_pointer + x_dim])
            )

            # Vertical: (ndates, vres, hres) -> (ndates, chi_vres, hres)
            if hasattr(tracer, "vcorrelations"):
                debug("Dealing with vertical correlation")
                bvert_t = _transposed_sqrt(tracer.vcorrelations)
                xstacks = xstd.reshape(
                    ndates, tracer.vresoldim, tracer.hresoldim)
                xstd = np.matmul(bvert_t, xstacks).flatten()

            # Temporal: (ndates, chi_vres, hres) -> (chi_tres, chi_vres, hres)
            # The vertical axis is already in chi space here, as in the
            # forward operator, hence chi_vresoldim and not vresoldim (they
            # differ when the vertical eigenbasis is truncated).
            if hasattr(tracer, "tcorrelations"):
                debug("Dealing with temporal correlations")
                btemp_t = _transposed_sqrt(tracer.tcorrelations)
                x_horizstacks = xstd.reshape(
                    ndates, tracer.chi_vresoldim, tracer.hresoldim)
                xstd = np.transpose(
                    np.matmul(btemp_t,
                              np.transpose(x_horizstacks, axes=(1, 0, 2))),
                    axes=(1, 0, 2)).flatten()

            # Horizontal: (chi_tres, chi_vres, hres) -> (..., chi_hres)
            if hasattr(tracer, "hcorrelations") \
                    and tracer.hresol != "global":
                debug("Dealing with horizontal correlations")
                bspace_t = _transposed_sqrt(tracer.hcorrelations)
                # A Fortran-order reshape of the C-ordered
                # (chi_tres, chi_vres, hres) buffer yields
                # (hres, chi_vres, chi_tres).
                x_tempstacks = xstd.reshape(
                    (tracer.hresoldim, tracer.chi_vresoldim,
                     tracer.chi_tresoldim),
                    order="F")
                # Flattening in Fortran order restores the C-ordered
                # (chi_tres, chi_vres, chi_hres) chi layout
                xstd = np.transpose(
                    np.matmul(bspace_t,
                              np.transpose(x_tempstacks, axes=(1, 0, 2))),
                    axes=(1, 0, 2)).flatten(order="F")

            # Filling Chi
            chiout[chi_pointer: chi_pointer + chi_dim] = xstd

    debug("Successfully computed dx -> dchi = {} . dx".format(label))
    return chiout