|
/* Version 2009-8-4 18:30:30
Lin Renwen
linrenwen@gmail.com
SUFE SHANGHAI
*/
/*
Purpose: A replication of ndfun, which was written by Peter Boettcher.
Multiple-page A and Multiple-page B
for index = 1:end
A(:,:, index) \ B(:,:, index);
end
Multiple-page A and Single-page B
for index = 1:end
A(:,:, index) \ B
end
where index may be multi-dimensional;
AMD
Vista 64 bit
Matlab R2009a
Visual Studio 2008
*/
#include "mex.h"
#include "matrix.h"
#include <stdlib.h>
#include "blas.h"
#include <math.h>
#include <string.h>
#include "lapack.h"
/* BLASCALL(f) produces the symbol used to call BLAS/LAPACK routine f:
   the bare name when one of the OS/2 / Windows / MSVC macros is defined,
   and the name with a trailing underscore otherwise (presumably the
   Fortran symbol convention of the linked library -- confirm against
   the toolchain in use). */
#if defined(__OS2__) || defined(__WINDOWS__) || defined(WIN32) || defined(_MSC_VER)
#define BLASCALL(f) f
#else
#define BLASCALL(f) f ## _
#endif
/* Forward declaration of the LU helper called once per page by mexFunction. */
void compute_lu(double *X, ptrdiff_t *ipivot, ptrdiff_t m, double *V, double *XT, ptrdiff_t *ISGN, ptrdiff_t *ISAVE);
/*
 * MEX gateway: page-wise left division, C = A \ B on every 2-D page.
 *
 *   prhs[0] (A)  full real double array; every page A(:,:,k) must be square.
 *   prhs[1] (B)  full real double array; pages must have as many rows as A.
 *   plhs[0] (C)  the solution pages.
 *
 * Supported layouts:
 *   - A multi-page, B multi-page (equal page counts):
 *         C(:,:,k) = A(:,:,k) \ B(:,:,k), C has the dimensions of B.
 *   - A multi-page, B 2-D:
 *         C(:,:,k) = A(:,:,k) \ B, trailing dimensions taken from A.
 *   - A 2-D, B multi-page (or 2-D):
 *         A is factored once and applied to all pages of B,
 *         C has the dimensions of B.
 *
 * Errors are raised through mexErrMsgTxt for a wrong argument count,
 * unsupported input classes, mismatched sizes or mismatched page counts.
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    double *A, *B, *C;
    double *scratchA = NULL, *V = NULL, *XT = NULL;
    ptrdiff_t *ipivot = NULL, *ISGN = NULL, *ISAVE = NULL;
    const mwSize *dimsA = NULL, *dimsB = NULL;
    mwSize *dimsC = NULL;
    mwSize numDimsA, numDimsB, numPagesA = 1, numPagesB = 1;
    mwSize strideA, strideB, strideC, page, d;
    ptrdiff_t mA, nA, mB, nB, nrhsAll, info = 0;
    char trans[] = "N";
    int arg;

    if (nrhs != 2) {
        mexErrMsgTxt("need two input!");
    }
    /* nlhs == 0 is fine: MATLAB stores plhs[0] in 'ans'. */
    if (nlhs > 1) {
        mexErrMsgTxt("Too many output arguments: only one output is produced.");
    }
    /* mxGetPr is only meaningful for full, real, double-precision data. */
    for (arg = 0; arg < 2; arg++) {
        if (!mxIsDouble(prhs[arg]) || mxIsComplex(prhs[arg]) || mxIsSparse(prhs[arg])) {
            mexErrMsgTxt("Inputs must be full, real, double-precision arrays.");
        }
    }

    numDimsA = mxGetNumberOfDimensions(prhs[0]);
    numDimsB = mxGetNumberOfDimensions(prhs[1]);
    dimsA = mxGetDimensions(prhs[0]);
    dimsB = mxGetDimensions(prhs[1]);
    mA = (ptrdiff_t)dimsA[0];
    nA = (ptrdiff_t)dimsA[1];
    mB = (ptrdiff_t)dimsB[0];
    nB = (ptrdiff_t)dimsB[1];

    /* Number of pages = product of all dimensions beyond the second
       (stays 1 for a plain 2-D matrix). */
    for (d = 2; d < numDimsA; d++) {
        numPagesA *= dimsA[d];
    }
    for (d = 2; d < numDimsB; d++) {
        numPagesB *= dimsB[d];
    }

    if (mA != nA || nA != mB) {
        mexErrMsgTxt("A is not a square matrix. OR or Inner dimensions of matrix multiply do not match.");
    }
    /* When both operands are paged they are walked in lock-step, so the
       page counts must agree or we would read/write past the arrays. */
    if (numDimsA > 2 && numDimsB > 2 && numPagesA != numPagesB) {
        mexErrMsgTxt("A and B must have the same number of pages, or one of them must be 2-D.");
    }

    strideA = (mwSize)(mA * nA);
    strideB = (mwSize)(mB * nB);
    strideC = (mwSize)(mA * nB);

    /* Create the output: B's shape when B is paged, otherwise one
       mB-by-nB page per page of A. */
    if (numDimsB == 2) {
        dimsC = (mwSize *)mxMalloc(numDimsA * sizeof(mwSize));
        dimsC[0] = (mwSize)mB;
        dimsC[1] = (mwSize)nB;
        for (d = 2; d < numDimsA; d++) {
            dimsC[d] = dimsA[d];
        }
        plhs[0] = mxCreateNumericArray(numDimsA, dimsC, mxDOUBLE_CLASS, mxREAL);
        mxFree(dimsC);
    }
    else {
        plhs[0] = mxCreateNumericArray(numDimsB, dimsB, mxDOUBLE_CLASS, mxREAL);
    }

    /* Nothing to solve for an empty result (zero rows, zero columns or
       zero pages); this also keeps empty data pointers away from
       memcpy and LAPACK. */
    if (mxIsEmpty(plhs[0])) {
        return;
    }

    A = mxGetPr(prhs[0]);
    B = mxGetPr(prhs[1]);
    C = mxGetPr(plhs[0]);

    /* Scratch buffers handed to compute_lu / dgetrs. XT is sized to mA
       doubles so compute_lu may use it as a full-length vector. */
    scratchA = (double *)mxMalloc(strideA * sizeof(double));
    V = (double *)mxMalloc(strideA * sizeof(double));
    XT = (double *)mxMalloc((size_t)mA * sizeof(double));
    ipivot = (ptrdiff_t *)mxMalloc((size_t)mA * sizeof(ptrdiff_t));
    ISGN = (ptrdiff_t *)mxMalloc((size_t)mA * sizeof(ptrdiff_t));
    ISAVE = (ptrdiff_t *)mxMalloc(3 * sizeof(ptrdiff_t));

    if (numDimsA == 2) {
        /* Single A: factor once. The pages of B are contiguous
           mB-by-nB column-major blocks, i.e. one mB-by-(nB*numPagesB)
           right-hand-side matrix, so a single dgetrs call solves them all. */
        memcpy(scratchA, A, strideA * sizeof(double));
        compute_lu(scratchA, ipivot, mA, V, XT, ISGN, ISAVE);
        memcpy(C, B, strideB * numPagesB * sizeof(double));
        nrhsAll = nB * (ptrdiff_t)numPagesB;
        BLASCALL(dgetrs)(trans, &mA, &nrhsAll, scratchA, &mA, ipivot, C, &mA, &info);
        if (info != 0) {
            mexErrMsgTxt("DGETRS rejected its arguments.");
        }
    }
    else {
        for (page = 0; page < numPagesA; page++) {
            double *pageC = C + page * strideC;
            const double *pageB = (numDimsB == 2) ? B : B + page * strideB;

            /* dgetrf works in place, so factor a copy of the page. */
            memcpy(scratchA, A + page * strideA, strideA * sizeof(double));
            compute_lu(scratchA, ipivot, mA, V, XT, ISGN, ISAVE);

            /* dgetrs overwrites the right-hand side with the solution. */
            memcpy(pageC, pageB, strideB * sizeof(double));
            BLASCALL(dgetrs)(trans, &mA, &nB, scratchA, &mA, ipivot, pageC, &mA, &info);
            if (info != 0) {
                mexErrMsgTxt("DGETRS rejected its arguments.");
            }
        }
    }

    mxFree(scratchA);
    mxFree(V);
    mxFree(XT);
    mxFree(ipivot);
    mxFree(ISGN);
    mxFree(ISAVE);
    return;
} //End mexFunction
/*-------------------------*//* APPENDIX *//*---------------------------*/
/* Wrapper function for LU decomposition. Optionally checks
singularity of result. For efficiency, pass in the scratch
buffers. Result appears in-place. See BLAS docs on DGETRF and
DGECON for required scratch buffer sizes. */
void compute_lu(double *X, ptrdiff_t *ipivot, ptrdiff_t m, double *V, double *XT, ptrdiff_t *ISGN, ptrdiff_t *ISAVE)
{
/* INPUT-OUTPUT of dgetrf() */
//ptrdiff_t m
//double *X
//ptrdiff_t *ipivot;
ptrdiff_t info;
/* INPUT-OUTPUT of dlacn2() */
//ptrdiff_t m
double EST = 0;
ptrdiff_t KASE = 0;
/* Others Definition for check_singular*/
ptrdiff_t check_singular = 1;
char errmsg[255];
double eps = mxGetEps();
/* LU Decomposition */
BLASCALL(dgetrf)(&m, &m, X, &m, ipivot, &info);
*XT = *X;
/* Check singularity */
if(check_singular) {
if(info>0){
mexWarnMsgTxt("Matrix is singular to working precision");
}
else {
// mexPrintf("m \t %d \t Est \t %d \t KASE \t %d \n", m , EST, KASE);
BLASCALL(dlacn2)( &m, V, XT, ISGN, &EST, &KASE, ISAVE);
// mexPrintf("m \t %d \t Est \t %d \t KASE \t %d \n", m , EST, KASE);
BLASCALL(dlacn2)( &m, V, XT, ISGN, &EST, &KASE, ISAVE);
// mexPrintf("m \t %d \t Est \t %d \t KASE \t %d \n", m , EST, KASE);
BLASCALL(dlacn2)( &m, V, XT, ISGN, &EST, &KASE, ISAVE);
// mexPrintf("m \t %d \t Est \t %d \t KASE \t %d \n", m , EST, KASE);
if(EST < eps) {
sprintf(errmsg, "%s\n %s RCOND = %e.", "Matrix is close to singular or badly scaled.",
"Results may be inaccurate.", EST);
mexWarnMsgTxt(errmsg);
}
} // End else
} // End if(check_singular)
return;
} // End compute_lu
/*-------------------------*//* REFERENCE *//*-------------------------*/
/*extern void dlacn2(
ptrdiff_t *n,
double *v,
double *x,
ptrdiff_t *isgn,
double *est,
ptrdiff_t *kase,
ptrdiff_t *isave
);*/
/*extern void dgetrf(
ptrdiff_t *m,
ptrdiff_t *n,
double *a,
ptrdiff_t *lda,
ptrdiff_t *ipiv,
ptrdiff_t *info
);*/
//For DLACN2:
/** EST (input/output) DOUBLE PRECISION
* On entry with KASE = 1 or 2 and ISAVE(1) = 3, EST should be
* unchanged from the previous call to DLACN2.
* On exit, EST is an estimate (a lower bound) for norm(A).
*
For DLACON:
* EST (input/output) DOUBLE PRECISION
* On entry with KASE = 1 or 2 and JUMP = 3, EST should be
* unchanged from the previous call to DLACON.
* On exit, EST is an estimate (a lower bound) for norm(A). */
|