LAMMPS计算MSD

  1. 6. 使用LAMMPS轨迹计算MSD(多起源-师弟版本)

6. 使用LAMMPS轨迹计算MSD(多起源-师弟版本)

import functools
import subprocess
import sys
import time
from collections import Counter

import numpy as np
from matplotlib import pyplot as plt


def timer(func):
    """Decorate *func* so that every call prints its elapsed wall-clock time.

    The wrapped function's return value is passed through unchanged.
    ``functools.wraps`` copies ``__name__``, ``__doc__`` etc. onto the wrapper,
    so decorated functions stay introspectable.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution, so the measured
        # interval cannot be skewed by system clock adjustments (time.time can).
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        print(f"Time elapsed: {elapsed:.4f} seconds")
        return result

    return wrapper


def sort_atoms(atoms: np.ndarray):
    """Return the rows of a per-atom table ordered by type, then by numeric id.

    Column 0 is the atom id, converted to int so that e.g. "10" sorts after "2".
    Column 1 is the atom type and is compared as stored (as strings for a string
    table). ``np.lexsort`` treats its LAST key as the primary one, so the key
    tuple (id, type) groups rows by type and orders ids within each type.
    """
    sort_keys = (atoms[:, 0].astype(int), atoms[:, 1])
    return atoms[np.lexsort(sort_keys)]


def _count_frames(strj_file: str) -> int:
    """Return how many lines of *strj_file* contain "TIMESTEP" (one per frame header)."""
    with open(strj_file, "r") as fh:
        return sum(1 for line in fh if "TIMESTEP" in line)


def _read_frame_atoms(fh):
    """Read the rest of one frame, starting just after its ``ITEM: TIMESTEP`` line.

    Consumes the frame's header items and its ``ITEM: ATOMS`` section.

    Returns:
        A 2-D string array with one row per atom and one column per
        whitespace-separated field, or None if the file ends before the frame
        is complete.
    """
    num_atoms = 0
    while True:
        line = fh.readline()
        if not line:
            return None
        if "ITEM" not in line:
            continue  # value lines of items not used here (e.g. the timestep)
        # "NUMBER" must be tested before "ATOMS": that header is "ITEM: NUMBER OF ATOMS".
        if "NUMBER" in line:
            num_atoms = int(fh.readline())
        elif "BOX" in line:
            # The box is not needed for the MSD, so its three bound lines are
            # just skipped. Parsing them as a fixed 3x3 array fails for
            # orthogonal boxes, whose lines hold only two values ("lo hi").
            for _ in range(3):
                fh.readline()
        elif "ATOMS" in line:
            # Build the 2-D table directly; flattening with sum(lists, []) is
            # quadratic in the number of atoms.
            rows = [fh.readline().split() for _ in range(num_atoms)]
            if rows and not rows[-1]:
                return None  # the file ends inside this ATOMS section
            return np.array(rows)


@timer
def read_strj(strj_file: str, intv: int = 1, begin: int = 0):
    """Read atom positions from a LAMMPS text dump file.

    Frames are counted from 0 in file order (NOT by LAMMPS timestep value).
    Every ``intv``-th frame starting at frame ``begin`` is kept. Inside each kept
    frame the atoms are reordered with ``sort_atoms`` (by type, then id) so row
    ``k`` refers to the same atom in every frame. A truncated final frame is
    dropped.

    Per-atom columns 0 and 1 are taken as id and type, and columns 2-4 as the
    x, y, z coordinates. NOTE(review): an MSD across periodic boundaries needs
    unwrapped coordinates, so the dump is presumably written as
    ``id type xu yu zu``; confirm the column order of your dump.

    Args:
        strj_file: path of the ``.lammpstrj`` file.
        intv: keep every ``intv``-th frame (must be >= 1).
        begin: index of the first frame to keep.

    Returns:
        ``(ids, types, positions)``. ``ids`` and ``types`` are string arrays
        taken from the last kept frame (None if no frame was kept).
        ``positions`` is ``np.array`` of the per-frame ``(n_atoms, 3)`` float
        arrays, i.e. shape ``(n_frames, n_atoms, 3)`` when every frame has the
        same atoms.
    """
    tot_step = _count_frames(strj_file)
    print(f"Total steps: {tot_step}")

    # select frames
    select_step = range(begin, tot_step, intv)
    print(f"These steps in strj file will be selected: {select_step}")

    la_id = None
    la_type = None
    data_pos = []
    with open(strj_file, "r") as fh:
        frame_idx = 0
        while True:
            line = fh.readline()
            if not line:
                break
            if "TIMESTEP" not in line:
                continue
            selected = frame_idx in select_step
            frame_idx += 1
            if not selected:
                # The skipped frame's remaining lines are consumed by this loop;
                # none of them contains "TIMESTEP".
                continue
            atoms = _read_frame_atoms(fh)
            if atoms is None:
                break  # truncated final frame
            atoms = sort_atoms(atoms)
            la_id = atoms[:, 0]
            la_type = atoms[:, 1]
            data_pos.append(atoms[:, 2:5].astype(float))
    return la_id, la_type, np.array(data_pos)


@timer
def msd1(x: np.ndarray,
         atoms_types: list):
    """Multiple-time-origin mean squared displacement, overall and per atom type.

    For every lag ``k`` in ``1 .. len(x) - 1``, the displacement
    ``x[t + k] - x[t]`` is squared. It is then averaged over all time origins
    ``t`` and over the atoms, either all of them or those of one type.

    Args:
        x: positions of shape ``(n_frames, n_atoms, 3)``. Row ``a`` must refer
            to the same atom in every frame.
        atoms_types: type label of each of the ``n_atoms`` atoms. Atoms of a
            type need not be contiguous. Per-type results are ordered by first
            appearance of the type in ``atoms_types``.

    Returns:
        Four lists with one entry per lag (index ``k - 1`` holds lag ``k``).
        All four are empty when ``x`` has fewer than two frames.

        - ``msd_xyz_split``: per type, the ``[x, y, z]`` components (shape (3,)).
        - ``msd_tot_split``: per type, the total ``x + y + z``.
        - ``msd_xyz``: all atoms, the ``[x, y, z]`` components.
        - ``msd_tot``: all atoms, the total.
    """
    types = np.asarray(atoms_types)
    # One boolean atom mask per type. Unlike splitting at cumulative counts,
    # this does not require atoms of the same type to be stored contiguously.
    type_masks = [types == t for t in dict.fromkeys(atoms_types)]

    data_msd_xyz_split = []
    data_msd_tot_split = []
    data_msd_xyz = []
    data_msd_tot = []
    for msd_intv in range(1, len(x)):
        # every frame that has a partner msd_intv frames later is a time origin
        d = x[msd_intv:] - x[:-msd_intv]
        d_square = d * d  # (n_origins, n_atoms, 3)
        per_type = [d_square[:, mask] for mask in type_masks]
        data_msd_xyz_split.append([np.average(s, axis=(0, 1)) for s in per_type])
        data_msd_tot_split.append([np.average(s.sum(2)) for s in per_type])  # x + y + z
        data_msd_xyz.append(np.average(d_square, axis=(0, 1)))
        data_msd_tot.append(np.average(d_square.sum(2)))
    return data_msd_xyz_split, data_msd_tot_split, data_msd_xyz, data_msd_tot


@timer
def plot_split_msd(msd_split,
                   atoms_type):
    """Plot per-type MSD curves, one subplot per component, and save the figure.

    The figure is saved as ``msd_types``. The name has no extension, so
    matplotlib's configured default format is used.

    Args:
        msd_split: array of shape ``(n_lag, n_types, n_components)``.
            ``msd_split[k, t, c]`` is the MSD of type ``t`` at lag ``k + 1``
            (in selected frames) for component ``c``, in the order X, Y, Z, All.
        atoms_type: per-atom types. Their unique values, in first-appearance
            order, label the curves.
    """
    row_title = list(dict.fromkeys(atoms_type))
    col_title = ["X", "Y", "Z", "All"]

    # msd1 starts at lag 1, so row k of msd_split belongs to lag k + 1.
    lags = np.arange(1, len(msd_split) + 1)
    n_col = msd_split.shape[2]

    fig = plt.figure(figsize=(24, 5), dpi=200)
    for i in range(n_col):
        ax = fig.add_subplot(1, n_col, i + 1)
        # 2-D y: one curve per type, labelled by the matching row_title entry
        ax.plot(lags, msd_split[:, :, i], label=row_title)
        ax.set_title(f'{col_title[i]}')
        ax.set_xlabel("step_intv")
        ax.set_ylabel(r"MSD   $A^2$")
        ax.legend()

    fig.savefig("msd_types")
    # release the figure; pyplot otherwise keeps it alive until the process exits
    plt.close(fig)


@timer
def plot_msd(msd_all):
    """Plot the all-atom MSD, one subplot per component, and save the figure.

    The figure is saved as ``msd_all``. The name has no extension, so
    matplotlib's configured default format is used.

    Args:
        msd_all: array of shape ``(n_lag, n_components)``. ``msd_all[k, c]`` is
            the MSD at lag ``k + 1`` (in selected frames) for component ``c``,
            in the order X, Y, Z, All.
    """
    col_title = ["X", "Y", "Z", "All"]

    # msd1 starts at lag 1, so row k of msd_all belongs to lag k + 1.
    lags = np.arange(1, len(msd_all) + 1)
    n_col = msd_all.shape[1]

    fig = plt.figure(figsize=(24, 5), dpi=200)
    for i in range(n_col):
        ax = fig.add_subplot(1, n_col, i + 1)
        ax.plot(lags, msd_all[:, i])
        ax.set_title(f'{col_title[i]}')
        ax.set_xlabel("step_intv")
        ax.set_ylabel(r"MSD   $A^2$")
    fig.savefig("msd_all")
    # release the figure; pyplot otherwise keeps it alive until the process exits
    plt.close(fig)


def write_split_msd(msd_split,
                    atoms_type,
                    write_name: str):
    """Write per-type MSD values as a whitespace-separated text table.

    The first column is the lag in selected frames. It starts at 1 because
    msd1's first row is lag 1. It is followed by one column per (type,
    component) pair, labelled ``<type>-<X|Y|Z|All>``.

    Args:
        msd_split: indexable as ``msd_split[lag_row][type][component]`` with
            4 components (X, Y, Z, All).
        atoms_type: per-atom types. Their unique values, in first-appearance
            order, name the type columns.
        write_name: path of the output file (overwritten).
    """
    row_title = list(dict.fromkeys(atoms_type))
    col_title = ["X", "Y", "Z", "All"]

    with open(write_name, "w") as f:
        # header widths match the data rows: 8 for the lag, 10 + 1 per value
        header = "".join(f"{f'{t}-{c}':>10} " for t in row_title for c in col_title)
        f.write(f"{'step':>8}{header}\n")

        for lag, per_type in enumerate(msd_split, start=1):
            values = "".join(
                f"{per_type[i][j]:10.4f} "
                for i in range(len(row_title))
                for j in range(len(col_title))
            )
            f.write(f"{lag:8}{values}\n")

def write_msd(msd,
              write_name: str):
    """Write the all-atom MSD as a whitespace-separated text table.

    The columns are the lag in selected frames, followed by the X, Y, Z and
    All MSD values. The lag starts at 1 because msd1's first row is lag 1.

    Args:
        msd: indexable as ``msd[lag_row][component]`` with 4 components.
        write_name: path of the output file (overwritten).
    """
    col_title = ["X", "Y", "Z", "All"]

    with open(write_name, "w") as f:
        # header widths match the data rows: 8 for the lag, 10 + 1 per value
        header = "".join(f"{title:>10} " for title in col_title)
        f.write(f"{'step':>8}{header}\n")

        # iterate over `msd` itself (the original used an unrelated global)
        for lag, row in enumerate(msd, start=1):
            values = "".join(f"{row[i]:10.4f} " for i in range(len(col_title)))
            f.write(f"{lag:8}{values}\n")




if __name__ == '__main__':
    # Optional command line: [strj_file] [intv] [begin].
    # With no arguments the original hard-coded defaults are used.
    strj_file = sys.argv[1] if len(sys.argv) > 1 else 'A.lammpstrj'
    intv = int(sys.argv[2]) if len(sys.argv) > 2 else 3
    begin = int(sys.argv[3]) if len(sys.argv) > 3 else 0

    atoms_id, atoms_types, x = read_strj(strj_file, intv, begin)

    print(atoms_types)

    # A single lag needs at least two selected frames; stop with a clear
    # message instead of an IndexError in the concatenations below.
    if x.ndim != 3 or len(x) < 2:
        sys.exit(f"Need at least 2 selected frames to compute an MSD, got {len(x)}.")

    msd_xyz_split, msd_tot_split, msd_xyz, msd_tot = msd1(x, atoms_types)

    # [msd_intv, types, xyzall]: per-type X, Y, Z components plus their sum
    msd_split = np.concatenate([np.array(msd_xyz_split), np.array(msd_tot_split)[:, :, None]],
                               axis=2)
    print(msd_split.shape)

    # [msd_intv, xyzall]: all-atom X, Y, Z components plus their sum
    msd_all = np.concatenate([np.array(msd_xyz), np.array(msd_tot)[:, None]], axis=1)
    print(msd_all.shape)

    print("plot - typed msd")
    plot_split_msd(msd_split, atoms_types)

    print("plot - msd")
    plot_msd(msd_all)

    print("write - typed msd")
    write_split_msd(msd_split, atoms_types, "./msd-types.txt")

    print("write - msd")
    write_msd(msd_all, "./msd.txt")

转载请注明来源；如有问题，可通过 GitHub 提交 issue。