# Source code for dqc.qccalc.ks

from typing import Optional, Dict, Any, List, Union, overload, Tuple
import torch
import xitorch as xt
import xitorch.linalg
import xitorch.optimize
from dqc.system.base_system import BaseSystem
from dqc.qccalc.scf_qccalc import SCF_QCCalc, BaseSCFEngine
from dqc.qccalc.hf import _HFEngine
from dqc.xc.base_xc import BaseXC
from dqc.api.getxc import get_xc
from dqc.utils.datastruct import SpinParam

__all__ = ["KS"]

class KS(SCF_QCCalc):
    """
    Performing Restricted or Unrestricted Kohn-Sham DFT calculation.

    Arguments
    ---------
    system: BaseSystem
        The system to be calculated.
    xc: str or BaseXC
        The exchange-correlation potential and energy to be used. A string is
        resolved with ``get_xc``; a ``BaseXC`` object is used as-is.
    vext: torch.Tensor or None
        The external potential applied to the system. It must have the shape of
        ``(*BV, system.get_grid().get_rgrid().shape[-2])``
    restricted: bool or None
        If True, performing restricted Kohn-Sham DFT. If False, it performs
        the unrestricted Kohn-Sham DFT.
        If None, it will choose True if the system is unpolarized and False if
        it is polarized
    variational: bool
        If True, then solve the Kohn-Sham equation variationally (i.e. using
        optimization) instead of using self-consistent iteration.
        Otherwise, solve it using self-consistent iteration.
    """

    def __init__(self, system: BaseSystem, xc: Union[str, BaseXC],
                 vext: Optional[torch.Tensor] = None,
                 restricted: Optional[bool] = None,
                 variational: bool = False):
        # ``restricted`` must be forwarded to the engine; previously it was
        # dropped here, so the documented option had no effect at all.
        engine = _KSEngine(system, xc, vext, restricted=restricted)
        super().__init__(engine, variational)
class _KSEngine(BaseSCFEngine):
    """
    Private class of Engine to be used with KS.
    This class provides the calculation of the self-consistency iteration step
    and the calculation of the post-calculation properties.

    The reason of this class' existence is the leak in PyTorch:
    https://github.com/pytorch/pytorch/issues/52140
    which can be solved by making a different class than the class where the
    self-consistent iteration is performed.
    """

    def __init__(self, system: BaseSystem, xc: Union[str, BaseXC],
                 vext: Optional[torch.Tensor] = None,
                 restricted: Optional[bool] = None):
        # The HF engine decides the polarization and is reused below for the
        # scp -> dm conversion and the aoparams handling.
        self.hf_engine = _HFEngine(system, restricted=restricted)
        self._polarized = self.hf_engine.polarized

        # get the xc object
        if isinstance(xc, str):
            self.xc: BaseXC = get_xc(xc)
        else:
            self.xc = xc
        system = self.hf_engine.get_system()
        self._system = system

        # build and setup basis and grid
        system.setup_grid()
        self.hamilton = system.get_hamiltonian()
        self.hamilton.setup_grid(system.get_grid(), self.xc)

        # get the orbital info
        self.orb_weight = system.get_orbweight(polarized=self._polarized)  # (norb,)
        self.norb = SpinParam.apply_fcn(lambda orb_weight: int(orb_weight.shape[-1]), self.orb_weight)

        # set up the vext linear operator
        self.knvext_linop = self.hamilton.get_kinnucl()  # kinetic, nuclear, and external potential
        if vext is not None:
            ngrid = system.get_grid().get_rgrid().shape[-2]
            assert vext.shape[-1] == ngrid, \
                "vext.shape[-1] (%d) must match the grid size (%d)" % (vext.shape[-1], ngrid)
            self.knvext_linop = self.knvext_linop + self.hamilton.get_vext(vext)

    def get_system(self) -> BaseSystem:
        """Return the system held by this engine (taken from the HF engine)."""
        return self._system

    @property
    def shape(self):
        # returns the shape of the density matrix
        return self.knvext_linop.shape

    @property
    def dtype(self):
        # returns the dtype of the density matrix
        return self.knvext_linop.dtype

    @property
    def device(self):
        # returns the device of the density matrix
        return self.knvext_linop.device

    @property
    def polarized(self):
        return self._polarized

    def dm2scp(self, dm: Union[torch.Tensor, SpinParam[torch.Tensor]]) -> torch.Tensor:
        """
        Convert a density matrix into a self-consistent parameter (scp).

        For an unpolarized ``dm`` (plain tensor) the scp is the full Fock
        matrix. For a polarized ``dm`` (SpinParam) the scp is the spin-up and
        spin-down Fock matrices concatenated along a new leading dimension.
        """
        if isinstance(dm, torch.Tensor):  # unpolarized
            # scp is the fock matrix
            return self.__dm2fock(dm).fullmatrix()
        else:  # polarized
            # scp is the concatenated fock matrix
            fock = self.__dm2fock(dm)
            mat_u = fock.u.fullmatrix().unsqueeze(0)
            mat_d = fock.d.fullmatrix().unsqueeze(0)
            return torch.cat((mat_u, mat_d), dim=0)

    def scp2dm(self, scp: torch.Tensor) -> Union[torch.Tensor, SpinParam[torch.Tensor]]:
        """Convert a self-consistent parameter (scp) to the density matrix (delegated to the HF engine)."""
        return self.hf_engine.scp2dm(scp)

    def scp2scp(self, scp: torch.Tensor) -> torch.Tensor:
        """One self-consistent iteration step: scp -> density matrix -> scp."""
        dm = self.scp2dm(scp)
        return self.dm2scp(dm)

    def aoparams2ene(self, aoparams: torch.Tensor, with_penalty: Optional[float] = None) -> torch.Tensor:
        """
        Calculate the energy from the atomic orbital params.

        If the HF engine returns a penalty term, it is added to the energy.
        """
        dm, penalty = self.aoparams2dm(aoparams, with_penalty)
        ene = self.dm2energy(dm)
        return (ene + penalty) if penalty is not None else ene

    def aoparams2dm(self, aoparams: torch.Tensor, with_penalty: Optional[float] = None) -> \
            Tuple[Union[torch.Tensor, SpinParam[torch.Tensor]], Optional[torch.Tensor]]:
        """Calculate the density matrix and the penalty factor (delegated to the HF engine)."""
        return self.hf_engine.aoparams2dm(aoparams, with_penalty)

    def pack_aoparams(self, aoparams: Union[torch.Tensor, SpinParam[torch.Tensor]]) -> torch.Tensor:
        """Pack the aoparams from tensor or SpinParam into a single tensor."""
        return self.hf_engine.pack_aoparams(aoparams)

    def unpack_aoparams(self, aoparams: torch.Tensor) -> Union[torch.Tensor, SpinParam[torch.Tensor]]:
        """Unpack the single tensor aoparams to SpinParam or a tensor."""
        return self.hf_engine.unpack_aoparams(aoparams)

    def set_eigen_options(self, eigen_options: Dict[str, Any]) -> None:
        """Set the eigendecomposition (diagonalization) option on the HF engine."""
        self.hf_engine.set_eigen_options(eigen_options)

    def dm2energy(self, dm: Union[torch.Tensor, SpinParam[torch.Tensor]]) -> torch.Tensor:
        """
        Calculate the total energy given the density matrix.

        The energy is the sum of the core (hcore) term, the electron repulsion
        term, the xc term and the nuclei energy of the system.
        """
        # the core and repulsion terms use the spin-summed density matrix,
        # while the xc term receives ``dm`` as given (possibly a SpinParam)
        dmtot = SpinParam.sum(dm)
        e_core = self.hamilton.get_e_hcore(dmtot)
        e_elrep = self.hamilton.get_e_elrep(dmtot)
        e_xc = self.hamilton.get_e_xc(dm)
        return e_core + e_elrep + e_xc + self._system.get_nuclei_energy()

    @overload
    def __dm2fock(self, dm: torch.Tensor) -> xt.LinearOperator:
        ...

    @overload
    def __dm2fock(self, dm: SpinParam[torch.Tensor]) -> SpinParam[xt.LinearOperator]:
        ...

    def __dm2fock(self, dm):
        # Fock operator = (kinetic + nuclear [+ vext]) + electron repulsion + vxc;
        # only vxc is spin-dependent, so the shared part is added to each spin.
        elrep = self.hamilton.get_elrep(SpinParam.sum(dm))  # (..., nao, nao)
        core_coul = self.knvext_linop + elrep
        vxc = self.hamilton.get_vxc(dm)  # spin param or tensor (..., nao, nao)
        return SpinParam.apply_fcn(lambda vxc_: vxc_ + core_coul, vxc)

    def getparamnames(self, methodname: str, prefix: str = "") -> List[str]:
        """
        Return the (``prefix``-ed) attribute names of the parameters that the
        method ``methodname`` depends on.

        Raises
        ------
        KeyError
            If ``methodname`` has no parameter list registered here.
        """
        if methodname == "scp2scp":
            return self.getparamnames("scp2dm", prefix=prefix) + \
                self.getparamnames("dm2scp", prefix=prefix)
        elif methodname == "scp2dm":
            return self.hf_engine.getparamnames("scp2dm", prefix=prefix + "hf_engine.")
        elif methodname == "dm2scp":
            # dm2scp only goes through __dm2fock (looked up by its plain name here)
            return self.getparamnames("__dm2fock", prefix=prefix)
        elif methodname == "aoparams2ene":
            return self.getparamnames("aoparams2dm", prefix=prefix) + \
                self.getparamnames("dm2energy", prefix=prefix)
        elif methodname in ["aoparams2dm", "pack_aoparams", "unpack_aoparams"]:
            return self.hf_engine.getparamnames(methodname, prefix=prefix + "hf_engine.")
        elif methodname == "dm2energy":
            hprefix = prefix + "hamilton."
            sprefix = prefix + "_system."
            return self.hamilton.getparamnames("get_e_hcore", prefix=hprefix) + \
                self.hamilton.getparamnames("get_e_elrep", prefix=hprefix) + \
                self.hamilton.getparamnames("get_e_xc", prefix=hprefix) + \
                self._system.getparamnames("get_nuclei_energy", prefix=sprefix)
        elif methodname == "__dm2fock":
            hprefix = prefix + "hamilton."
            return self.hamilton.getparamnames("get_elrep", prefix=hprefix) + \
                self.hamilton.getparamnames("get_vxc", prefix=hprefix) + \
                self.knvext_linop._getparamnames(prefix=prefix + "knvext_linop.")
        else:
            raise KeyError("Method %s has no paramnames set" % methodname)