# Source code for dqc.api.parser

from typing import Union, Tuple
import torch
from dqc.utils.datastruct import AtomZsType, AtomPosType
from dqc.utils.periodictable import get_atomz

# Names exported on `from <this module> import *`; parse_moldesc is the only
# public entry point listed here.
__all__ = ["parse_moldesc"]

def parse_moldesc(moldesc: Union[str, Tuple[AtomZsType, AtomPosType]],
                  dtype: torch.dtype = torch.float64,
                  device: torch.device = torch.device('cpu')) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Parse the molecular descriptor and return tensors of atomzs and atom
    positions.

    Arguments
    ---------
    moldesc: str or tuple
        Either a string that describes the system, e.g. ``"H -1 0 0; H 1 0 0"``
        for an H2 molecule separated by 2 Bohr, or a tuple
        ``(atomzs, atompos)``.

        * String form: atoms are separated by ``;``. Each atom is an element
          (converted with ``get_atomz``) followed by its coordinates, all
          separated by whitespace. Blank segments (e.g. from a trailing
          ``;``) are ignored.
        * Tuple form: ``atomzs`` is either a tensor or a sequence whose
          entries are converted with ``get_atomz``; ``atompos`` is either a
          tensor or anything accepted by ``torch.as_tensor``. Both must have
          the same, non-zero length.
    dtype: torch.dtype
        The datatype of the returned atomic positions. Floating-point atomzs
        are also cast to this dtype; integer atomzs keep their integer dtype.
    device: torch.device
        The device to store the returned tensors.

    Returns
    -------
    tuple of 2 tensors
        The first element is the tensor of atomz, and the second element is the
        tensor of atomic positions.

    Raises
    ------
    AssertionError
        If the descriptor contains no atoms, or (tuple form) if the lengths of
        atomzs and atompos differ.
    """
    if isinstance(moldesc, str):
        # TODO: use regex!
        # str.split() with no argument already strips surrounding whitespace,
        # so the tokens need no extra .strip(). Segments with no tokens (e.g.
        # produced by a trailing ";" or ";;") are skipped rather than crashing
        # with an IndexError on line[0] below.
        segments = [segment.split() for segment in moldesc.split(";")]
        elmts = [
            [get_atomz(tok) if i == 0 else float(tok) for i, tok in enumerate(tokens)]
            for tokens in segments
            if tokens
        ]
        assert len(elmts) > 0, "Empty atom list"
        atomzs = torch.tensor([line[0] for line in elmts], device=device)
        atompos = torch.tensor([line[1:] for line in elmts], dtype=dtype, device=device)

    else:  # tuple of (atomzs, atompos)
        atomzs_raw, atompos_raw = moldesc
        assert len(atomzs_raw) == len(atompos_raw), "Mismatch length of atomz and atompos"
        assert len(atomzs_raw) > 0, "Empty atom list"

        # convert the atomz to tensor
        if isinstance(atomzs_raw, torch.Tensor):
            atomzs = atomzs_raw.to(device)
        else:
            atomzs = torch.tensor([get_atomz(at) for at in atomzs_raw], device=device)

        # convert the atompos to tensor; a single .to() call casts and moves
        # in one step instead of potentially copying twice
        if isinstance(atompos_raw, torch.Tensor):
            atompos = atompos_raw.to(device=device, dtype=dtype)
        else:
            atompos = torch.as_tensor(atompos_raw, dtype=dtype, device=device)

    # only floating-point atomzs are cast to the requested dtype; integer
    # atomzs are left as integer tensors
    if atomzs.is_floating_point():
        atomzs = atomzs.to(dtype)

    return atomzs, atompos