Posted By

benjamin on 10/29/09


Tagged

distance matlab numerical computing armadillo mex mahalanobis


Versions (?)

MEX Function for the Mahalanobis Distance


 / Published in: C++
 

URL: http://www.myoutsourcedbrain.com/2009/08/mex-function-for-mahalanobis-function.html

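This code demonstrates how to write MATLAB MEX functions using the Armadillo C++ linear algebra library. Given X (m×d), Y (n×d), and a square d×d weight matrix W, it returns the m×n matrix D with D(i,j) = (X(i,:) − Y(j,:)) · W · (X(i,:) − Y(j,:))'. When W is the inverse covariance matrix, this is the squared Mahalanobis distance. Please see my blog post for more details and leave comments there.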

  1. #include "mex.h"
  2. #include "math.h"
  3. #include<armadillo>
  4.  
  5. using namespace arma;
  6.  
  7. void importMatlab(mat& A, const mxArray *mxdata){
  8. access::rw(A.mem)=mxGetPr(mxdata);
  9. access::rw(A.n_rows)=mxGetM(mxdata);
  10. access::rw(A.n_cols)=mxGetN(mxdata);
  11. access::rw(A.n_elem)=A.n_rows*A.n_cols;
  12. };
  13.  
  14. void freeVar(mat& A, const double *ptr){
  15. access::rw(A.mem)=ptr;
  16. access::rw(A.n_rows)=1;
  17. access::rw(A.n_cols)=1;
  18. access::rw(A.n_elem)=1;
  19. };
  20.  
  21.  
  22. void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
  23. {
  24. if (nrhs != 3)
  25. mexErrMsgTxt("Incorrect number of input arguments");
  26. if (nlhs != 1)
  27. mexErrMsgTxt("Incorrect number of D arguments");
  28.  
  29. mat X(1,1);
  30. const double* Xmem=access::rw(X.mem);
  31. importMatlab(X,prhs[0]);
  32.  
  33. mat Y(1,1);
  34. const double* Ymem=access::rw(Y.mem);
  35. importMatlab(Y,prhs[1]);
  36.  
  37. mat W(1,1);
  38. const double* Wmem=access::rw(W.mem);
  39. importMatlab(W,prhs[2]);
  40.  
  41. // check if the input corresponds to what you are expecting
  42. if(( X.n_cols != Y.n_cols ) || (Y.n_cols != W.n_cols ) )
  43. mexErrMsgTxt("Columns of X, Y, and W must be of equal length!");
  44. if( W.n_rows != W.n_cols )
  45. mexErrMsgTxt("W must be a square matrix!");
  46.  
  47. plhs[0] = mxCreateDoubleMatrix(X.n_rows, Y.n_rows, mxREAL);
  48.  
  49. double *out = mxGetPr(plhs[0]);
  50. mat diff;
  51. mat M;
  52. int k=0;
  53. for(int y=0;y<Y.n_rows;y++)
  54. for(int x=0;x<X.n_rows;x++){
  55. // (X(x,:)-Y(y,:))*W*((X(x,:)-Y(y,:))')
  56. diff=X.row(x)-Y.row(y);
  57. M=diff*W*trans(diff);
  58. (*out++)=M(0);
  59. }
  60.  
  61.  
  62. freeVar(X,Xmem);
  63. freeVar(Y,Ymem);
  64. freeVar(W,Wmem);
  65. return;
  66. }

Report this snippet  

You need to log in to post a comment.