MATLAB的MEX函数可以通过编译的机器代码,替代低效的脚本语言提升运行效率(以及隐藏原始代码保护知识产权)。MEX函数最初支持C语言编写,从2018a开始支持基于C++11的“现代”C++编写MEX,并实现更多“现代”特性(主要是程序内存安全性)。目前,MATLAB官方已不推荐继续用传统C语言编写新的MEX函数:C Matrix API - MATLAB & Simulink - MathWorks 中国。
在这个博客,我展示一个搜索质数本源根的程序,详细说明如何把C++代码嵌入MEX函数中。
首先说明什么是本源根:本源根是离散数学上的一个概念,考虑对一个质数 p p p,若正整数 a a a的各次幂除以 p p p的余数正好可以产生 1 1 1到 p − 1 p-1 p−1的所有整数,即 a mod p a \text{ mod } p a mod p、 a 2 mod p a^2 \text{ mod } p a2 mod p、…、 a p − 1 mod p a^{p-1} \text{ mod } p ap−1 mod p各不相同,那么 a a a就是 p p p的一个本源根。本源根在非对称密码学上非常有用,而且显而易见,对于大质数,求出全部本源根需要穷举,这种穷举就非常适合通过C/C++编程来优化。
首先给出代码:
#include "mex.hpp" #include "mexAdapter.hpp" using namespace matlab::data; using matlab::mex::ArgumentList; class MexFunction : public matlab::mex::Function { public: void operator()(ArgumentList outputs, ArgumentList inputs) { std::shared_ptr<matlab::engine::MATLABEngine> matlabPtr = getEngine(); ArrayFactory factory; // Validate arguments checkArguments(outputs, inputs); TypedArray<double> inputArray = std::move(inputs[0]); int num =(int)inputArray[0]; if (num <= 2){ matlabPtr->feval(u"error", 0, std::vector<Array>({ factory.createScalar("Input must be an integer larger than 2") })); } // 算法实现 int i, j, cur_num; int search_list[num]; bool isPriRoot; std::vector<int> vec_priroots; for (i=2; i<num-1; i++){ isPriRoot = true; for (j=0; j<num; j++) search_list[j] = 0; cur_num = i; for (j=0; j<num - 1; j++){ search_list[cur_num]++; if (search_list[cur_num] >= 2) { isPriRoot = false; break; } cur_num = (cur_num * i) % num; } if (isPriRoot) vec_priroots.push_back(i); } // Assign outputs TypedArray<double> pri_root_Array = factory.createArray<double>({vec_priroots.size(),1}); for (i=0; i<vec_priroots.size(); i++){ pri_root_Array[i] = vec_priroots[i]; } outputs[0] = pri_root_Array; } void checkArguments(ArgumentList outputs, ArgumentList inputs) { std::shared_ptr<matlab::engine::MATLABEngine> matlabPtr = getEngine(); ArrayFactory factory; if (inputs[0].getType() != ArrayType::DOUBLE || inputs[0].getType() == ArrayType::COMPLEX_DOUBLE || inputs[0].getNumberOfElements() != 1) { matlabPtr->feval(u"error", 0, std::vector<Array>({ factory.createScalar("Input must be an integer larger than 2") })); } if (outputs.size() > 1) { matlabPtr->feval(u"error", 0, std::vector<Array>({ factory.createScalar("Only one output is returned") })); } } };
从头看起,MEX C++严格来说不是在编写函数, 而是在编写一个名为MexFunction
的类,继承自matlab::mex::Function
,然后重载这个类的括号运算符operator()
。
void operator()(ArgumentList outputs, ArgumentList inputs)
与C语言的MEX接口不同,这里通过两个matlab::mex::ArgumentList
容器,分别传入函数的输入inputs
和输出outputs
。它们传入的数据类型为TypedArray<T>
,可以是MATLAB的二维或多维矩阵。
在检验输入数据符合算法要求后(checkArguments(outputs, inputs)
,这里略过该函数的实现),我们用std::move
接收第一个输入Array的引用(这是官方推荐做法,避免因类型转换多生成副本),然后将该Array第一个数据取出,就是我们算法的输入,即待求本源根的质数 p p p:
TypedArray<double> inputArray = std::move(inputs[0]); int num =(int)inputArray[0];
这里需要说明,如果输入参数不止一个,可以用如下循环逐个访问类型为输入数据中的参数。为了避免复制过大的数据矩阵,迭代器采用了引用类型:
TypedArray<double> doubleArray = std::move(inputs[0]); for (auto& elem : doubleArray) { // do something with elem, the type of elem is double }
再之后的算法实现就是搜索每一个小于 p p p的整数的各次幂余数是否重复,这个过程很简单,我也就用C语言实现(不推荐!),不多赘述。
现在说明一些不太需要深究的内容。本代码声明了两个变量:
std::shared_ptr<matlab::engine::MATLABEngine> matlabPtr = getEngine(); ArrayFactory factory;
其中matlabPtr
主要的作用是通过feval
方法调用MATLAB的函数,在本段代码中只用于输出错误信息;matlab::data::ArrayFactory
类可以通过模板产生类型为TypedArray<T>
的数据,故我们在最后调用该类组装输出数据:
TypedArray<double> pri_root_Array = factory.createArray<double>({vec_priroots.size(),1}); for (i=0; i<vec_priroots.size(); i++){ pri_root_Array[i] = vec_priroots[i]; }
最后,通过
outputs[0] = pri_root_Array;
将组装的数据交给输出容器的第一个数,大功告成。
可以看到,MATLAB通过容器向MEX函数(对象)传入传出数据。我们在编写MEX C++代码的时候,只要能从ArgumentList inputs
取出输入数据,再把输出结果放进ArgumentList outputs
,其他的实现和我们平时编写C++代码没有任何区别。特别值得说明的是,各种C++11的标准库容器(比如std::vector
)都可以在MEX函数中使用,这也方便我们移植已有算法。