Source code for acados_solver_plugins.solver_plugin_generator

# Copyright 2023 ICUBE Laboratory, University of Strasbourg
# License: Apache License, Version 2.0
# Author: Thibault Poignonec (tpoignonec@unistra.fr)

from copy import deepcopy
import os

from acados_template import AcadosOcp, AcadosOcpSolver
from acados_template import ocp_get_default_cmake_builder
# import casadi as ca
from jinja2 import Environment, PackageLoader  # noqa: I100

import numpy as np

from acados_solver_plugins.utils import delete_dir_recursively  # noqa: I100
from acados_solver_plugins.utils import ensure_dir_exists
from acados_solver_plugins.utils import uppercase_to_underscore

_library_name = 'acados_solver_plugins'


# Custom filters for jinja2


def filter_uppercase_to_underscore(name):
    return uppercase_to_underscore(name)


def filter_curly_bracket_list(list_of_index):
    if isinstance(list_of_index, list):
        return '{%s}' % str(list_of_index).strip('[]')
    elif isinstance(list_of_index, np.ndarray):
        return '{%s}' % str(list_of_index.reshape((-1,)).tolist()).strip('[]')
    else:
        raise Exception(
            'jinja2:filter_curly_bracket_list -> unsupported type!!!')


# Plugin generator

[docs] class SolverPluginGenerator: """ Minimalist template-based code generation module to generate \ `acados::AcadosSolver` derived objects. Relevant external API documentation : \ `jinja2 <https://jinja.palletsprojects.com/en/3.0.x/api/>`_, \ `acados_template <https://docs.acados.org/python_interface/index.html#>`_. """ def __init__( self, custom_export_path: str = None, library_name: str = None ): """ Solver plugin generator constructor. :param custom_export_path: Dir where the generated files will \ be exported, defaults to None :type custom_export_path: str, optional :param library_name: Name of the library that the generated files \ will be part of, defaults to None (in which case the library \ is set to "acados_solver_plugins" internally) :type library_name: str, optional """ # Setup paths if custom_export_path is not None: self.__plugin_export_path = deepcopy(custom_export_path) else: self.__plugin_export_path = os.environ.get( 'ACADOS_SOLVER_PLUGINGS_GENERATION_EXPORT_DIR' ) print('SolverPluginGenerator: \ All generated files will be exported to "' + self.__plugin_export_path + '".') # Setup jinja2 environment #: Jinja environment loader initialized as a `jinja2.PackageLoader` \ #: looking for templates in the installation folder of the module. self.jinga_env_loader = PackageLoader(_library_name) self.jinja_env = Environment(loader=self.jinga_env_loader) self.jinja_env.filters[ 'curly_bracket_list'] = filter_curly_bracket_list self.jinja_env.filters[ 'uppercase_to_underscore'] = filter_uppercase_to_underscore # Jinja2 templates self.__solver_pluging_cpp_template = 'acados_solver_plugin.cpp.j2' self.__solver_pluging_hpp_template = 'acados_solver_plugin.hpp.j2' self.__solver_pluging_cmake_template = 'add_pluging_to_build.cmake.j2' self.__solver_pluging_export_xml_template = 'export_plugin.xml.j2' # Misc. options self.__library_name = deepcopy(library_name) if library_name is None: self.__library_name = _library_name self.__clean_c_build_folder = True self.__generate_libplugin_export = True self.__default_plugin_description = \ 'Acados solver plugin based on "acados_solver_base".' self.__flag_regenerate_cmake = '.flag_regenerate_cmake'
[docs] def _generate_file_from_template( self, template_filename, export_filename, **template_render_args): template = self.jinja_env.get_template(template_filename) try: with open(export_filename, 'w') as f: f.write(template.render(**template_render_args)) except Exception as e: print( 'Failed to generate the file ' + '"{}" from the template "{}". Exception: {}'.format( export_filename, template_filename, e)) raise
[docs] def generate_solver_plugin(self, acados_ocp: AcadosOcp, plugin_class_name: str, solver_description: str = None, x_index_map: dict = None, z_index_map: dict = None, p_index_map: dict = None, u_index_map: dict = None) -> AcadosOcpSolver: """ Generate the solver plugin from a python Acados OCP model. :param acados_ocp: Acados OCP model :type acados_ocp: AcadosOcp :param plugin_class_name: Class name for the generated plugin of type "acados::<plugin_class_name>" :type plugin_class_name: str :param solver_description: Short description for the exported plugin, defaults to None :type solver_description: str, optional :param x_index_map: Differential state index map, defaults to None :type x_index_map: dict, optional :param z_index_map: Algebraic state index map, defaults to None :type z_index_map: dict, optional :param p_index_map: Parameters index map, defaults to None :type p_index_map: dict, optional :param u_index_map: Parameters index map, defaults to None :type u_index_map: dict, optional :return: the acados OCP solver generated as a by-product :rtype: AcadosOcpSolver """ # Make sure dir exist and clean if it does path_plugin_dir = os.path.join( self.__plugin_export_path, uppercase_to_underscore(plugin_class_name) ) print('Cleaning the target directory "%s"!' % (path_plugin_dir)) if os.path.exists(path_plugin_dir): delete_dir_recursively(path_plugin_dir) else: ensure_dir_exists(path_plugin_dir) # Rename acados model and collect relevant information acados_ocp.model.name = uppercase_to_underscore(plugin_class_name) acados_ocp.code_export_directory = os.path.join( path_plugin_dir, 'generated_c_code') acados_ocp_json_filename = os.path.join( path_plugin_dir, 'acados_ocp_' + acados_ocp.model.name + '.json') # Generate C-code print('Start generating/building the acados c-code for plugin "%s"!' % ('acados::'+plugin_class_name)) ocp_cmake_builder = ocp_get_default_cmake_builder() ocp_solver = AcadosOcpSolver( acados_ocp, json_file=acados_ocp_json_filename, cmake_builder=ocp_cmake_builder ) # Generate render args if solver_description is None: solver_description = self.__default_plugin_description template_render_args = { 'solver_c_prefix': uppercase_to_underscore(plugin_class_name), 'plugin_class_name': plugin_class_name, 'solver_description': solver_description, 'x_index_map': x_index_map, 'z_index_map': z_index_map, 'p_index_map': p_index_map, 'u_index_map': u_index_map, 'library_name': self.__library_name, 'export_plugin': self.__generate_libplugin_export, } # Generate c++ wrapper from template cpp_impl_file = os.path.join( path_plugin_dir, ''+uppercase_to_underscore(plugin_class_name)+'.cpp' ) hpp_impl_file = os.path.join( path_plugin_dir, ''+uppercase_to_underscore(plugin_class_name)+'.hpp' ) print('Generating "%s" from template...' % (cpp_impl_file)) self._generate_file_from_template( self.__solver_pluging_cpp_template, cpp_impl_file, **template_render_args ) print('Generating "%s" from template...' % (hpp_impl_file)) self._generate_file_from_template( self.__solver_pluging_hpp_template, hpp_impl_file, **template_render_args ) # Generate plugin export and cmake utils cmake_file = os.path.join( path_plugin_dir, 'add_pluging_to_build.cmake') plugin_xml_file = os.path.join(path_plugin_dir, 'export_plugin.xml') print('Generating "%s" from template...' % (cmake_file)) self._generate_file_from_template( self.__solver_pluging_cmake_template, cmake_file, **template_render_args ) print('Generating "%s" from template...' % (plugin_xml_file)) self._generate_file_from_template( self.__solver_pluging_export_xml_template, plugin_xml_file, **template_render_args ) # Clean build files if self.__clean_c_build_folder: # Delete "custom_export_path/c_generated_code/build" folder legacy_build_folder = os.path.join( path_plugin_dir, 'generated_c_code/build/') if os.path.exists(legacy_build_folder): delete_dir_recursively( legacy_build_folder, keep_empty_folder=False) # Rename 'custom_export_path/c_generated_code/CMakeLists.txt' os.rename( os.path.join(path_plugin_dir, 'generated_c_code/CMakeLists.txt'), os.path.join(path_plugin_dir, 'generated_c_code/legacy_CMakeLists.txt') ) # Trigger CMAKE build cmake_flag_name = os.path.join( self.__plugin_export_path, self.__flag_regenerate_cmake ) if (os.path.exists(cmake_flag_name)): try: os.remove(cmake_flag_name) with open(cmake_flag_name, mode='a'): pass except Exception as error: print( 'Warning, failed to reset the CMAKE flag ', '"{cmake_flag_name}". Error: ', error ) else: print( f'Warning, the CMAKE flag "{cmake_flag_name}" does not exist.' ) # return the python wrapper of the acados ocp in case of return ocp_solver