Source code for dqc.xc.custom_xc

from abc import abstractmethod, abstractproperty
from typing import Union, List
import torch
from dqc.xc.base_xc import BaseXC
from dqc.utils.datastruct import ValGrad, SpinParam

class CustomXC(BaseXC, torch.nn.Module):
    """
    Base class of custom xc functional.

    Subclasses must implement ``family`` and ``get_edensityxc``. Because this
    class is also a ``torch.nn.Module``, every parameter registered on a
    subclass instance is reported by ``getparamnames("get_edensityxc")``.
    """

    # ``@property`` stacked over ``@abstractmethod`` is the supported
    # replacement for the deprecated ``abc.abstractproperty``; ``property``
    # must be the outermost decorator so the abstract flag propagates.
    @property
    @abstractmethod
    def family(self) -> int:
        """
        Integer code identifying the family of this functional.

        NOTE(review): the meaning of the returned values is defined by
        ``BaseXC``, not here; confirm there.
        """

    @abstractmethod
    def get_edensityxc(
        self, densinfo: Union[ValGrad, SpinParam[ValGrad]]
    ) -> torch.Tensor:
        """
        Compute the xc energy density from the given density information.

        Arguments
        ---------
        densinfo: ValGrad or SpinParam[ValGrad]
            Density information. Per the annotation, either a single
            ``ValGrad`` or a ``SpinParam`` of ``ValGrad`` objects
            (presumably the unpolarized / spin-polarized cases; confirm
            against ``BaseXC``).

        Returns
        -------
        torch.Tensor
            The energy density computed by the subclass.
        """

    def getparamnames(self, methodname: str = "", prefix: str = "") -> List[str]:
        """
        Return the names of the parameters that ``methodname`` depends on.

        Arguments
        ---------
        methodname: str
            Name of the method being queried. For ``"get_edensityxc"``, all
            parameters registered on this module are returned. Any other name
            is delegated to the parent implementation.
        prefix: str
            String prepended to every returned name. A trailing ``"."`` is
            accepted, e.g. ``"xc."`` and ``"xc"`` yield the same names.

        Returns
        -------
        List[str]
            Fully prefixed parameter names.
        """
        if methodname == "get_edensityxc":
            # ``named_parameters`` inserts its own "." between a non-empty
            # prefix and each name, so strip one trailing "." to avoid
            # producing names like "prefix..weight".
            pfix = prefix[:-1] if prefix.endswith(".") else prefix
            return [name for name, _ in self.named_parameters(prefix=pfix)]
        return super().getparamnames(methodname, prefix=prefix)