Source code for mlff_attack.cli.make_attack

#!/usr/bin/env python3
"""
CLI entry point for MACE single structure attack.
"""

import argparse
from pathlib import Path
import torch
import numpy as np
import argparse


import matplotlib.pyplot as plt
from mlff_attack.relaxation import (
    load_structure,
    setup_calculator,
)
from mlff_attack.attacks import make_attack, visualize_perturbation

[docs] def parse_args(): """Parse command line arguments.""" parser = argparse.ArgumentParser( description="Perform adversarial attack on atomic structures using MACE model", formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument( "--input", type=str, default="initial_cifs/chemistry_value_isovalent_0_05_18_traj.cif", help="Path to input CIF file" ) parser.add_argument( "--model", type=str, default="mace-mpa-0-medium.model", help="Path to MACE model file" ) parser.add_argument( "--device", type=str, default="cpu", choices=["cpu", "cuda"], help="Device to run model on" ) parser.add_argument( "--epsilon", type=float, default=0.05, help="Perturbation step size in Angstroms" ) parser.add_argument( "--outdir", type=str, default=None, help="Path to output directory (default: auto-generated from input with '_perturbed' suffix)" ) parser.add_argument( "--target-energy", type=float, default=None, help="Target energy for attack (if None, maximize energy)" ) parser.add_argument( "--visualize", action="store_true", default=True, help="Generate visualization plot" ) parser.add_argument( "--no-visualize", action="store_false", dest="visualize", help="Skip visualization plot generation" ) parser.add_argument( "--type", type=str, default="fgsm", choices=["fgsm", "pgd", "bim"], help="Type of adversarial attack to perform" ) return parser.parse_args()
[docs] def main(): # Parse command line arguments args = parse_args() # Override configuration with command line arguments input_cif = args.input model_path = args.model device = args.device epsilon = args.epsilon target_energy = args.target_energy attack_type = args.type # Determine output path if args.outdir is not None: output_cif = Path(args.outdir) / (Path(input_cif).stem + "_perturbed.cif") else: output_cif = Path(input_cif).with_name(Path(input_cif).stem + "_perturbed.cif") # Load structure print(f"\nLoading structure from: {input_cif}") atoms = load_structure(input_cif) if atoms is None: raise RuntimeError(f"Failed to load structure from {input_cif}") print(f" Loaded {len(atoms)} atoms: {atoms.get_chemical_formula()}") # Generate perturbed structure print(f"\nGenerating perturbed structure with epsilon={epsilon} Å") output_file, perturbed_atoms, attack_details = make_attack( atoms=atoms, model_path=model_path, device=device, epsilon=epsilon, target_energy=target_energy, output_cif=output_cif, attack_type=attack_type ) # Visualize perturbation if args.visualize: print(f"\nVisualizing perturbation") # Store output filename in atoms info for visualization perturbed_atoms.info['filename'] = str(output_cif) fig = visualize_perturbation(atoms, perturbed_atoms, epsilon=epsilon, outdir=Path(output_cif).parent) plt.close(fig) if output_file: return 0 else: return 1
if __name__ == "__main__": exit(main())