tkf on 09/05/08

## Who likes this?

1 person has marked this snippet as a favorite

# Example of Numpy/C API

/ Published in: C

Some macros are taken from "Python Scripting for Computational Science": http://folk.uio.no/hpl/scripting/

`/* dtlsmodule.c */#include <math.h>#include <stdio.h> #include <Python.h>#include "structmember.h"#include <numpy/arrayobject.h> /* ================================================================= MACROS */#define QUOTE(s) # s   /* turn s into string "s" */#define NDIM_CHECK(a, expected_ndim, rt_error)				\  if (PyArray_NDIM(a) != expected_ndim) {				\    PyErr_Format(PyExc_ValueError,					\		 "%s array is %d-dimensional, but expected to be %d-dimensional", \		 QUOTE(a), PyArray_NDIM(a), expected_ndim);		\    return rt_error;							\  }#define DIM_CHECK(a, dim, expected_length, rt_error)			\  if (dim > PyArray_NDIM(a)) {						\    PyErr_Format(PyExc_ValueError,					\		 "%s array has no %d dimension (max dim. is %d)",	\		 QUOTE(a), dim, PyArray_NDIM(a));			\    return rt_error;							\  }									\  if (PyArray_DIM(a, dim) != expected_length) {				\    PyErr_Format(PyExc_ValueError,					\		 "%s array has wrong %d-dimension=%d (expected %d)",	\		 QUOTE(a), dim, PyArray_DIM(a, dim), expected_length);	\    return rt_error;							\  }#define TYPE_CHECK(a, tp, rt_error)					\  if (PyArray_TYPE(a) != tp) {						\    PyErr_Format(PyExc_TypeError,					\		 "%s array is not of correct type (%d)", QUOTE(a), tp); \    return rt_error;							\  }#define CALLABLE_CHECK(func, rt_error)				\  if (!PyCallable_Check(func)) {				\    PyErr_Format(PyExc_TypeError,				\		 "%s is not a callable function", QUOTE(func)); \    return rt_error;						\  } #define DIND1(a, i) *((double *) PyArray_GETPTR1(a, i))#define DIND2(a, i, j) *((double *) PyArray_GETPTR2(a, i, j))#define DIND3(a, i, j, k) *((double *) Py_Array_GETPTR3(a, i, j, k)) #define IIND1(a, i) *((int *) PyArray_GETPTR1(a, i))#define IIND2(a, i, j) *((int *) PyArray_GETPTR2(a, i, j))#define IIND3(a, i, j, k) *((int *) Py_Array_GETPTR3(a, i, j, k))  #define DEF_PYARRAY_GETTER(funcname, selftype, valname)	\  static PyObject *					\  funcname(selftype *self, void *closure)		\  {							\    Py_INCREF(self->valname);				\    return 
PyArray_Return(self->valname);		\  }#define DEF_PYARRAY_SETTER(funcname, selftype, valname, arraydim)	\  static int								\  funcname(selftype *self, PyObject *value, void *closure)		\  {									\    if (value == NULL) {						\      PyErr_SetString( PyExc_TypeError,					\		       "Cannot delete the last attribute");		\      return -1;							\    }									\    if ( PyArray_Check(value) != 1 ){					\      PyErr_Format( PyExc_ValueError,					\		    "value is not of type numpy array");		\      return -1;							\    }									\    if ( PyArray_NDIM(value) != arraydim ){				\      PyErr_Format( PyExc_ValueError,					\		    "value array's dimension %d != arraydim",		\		    PyArray_NDIM(value));				\      return -1;							\    }									\    if ( PyArray_TYPE(value) != NPY_DOUBLE ){				\      PyErr_Format( PyExc_ValueError,					\		    "value array is not of type 'Python float'");	\      return -1;							\    }									\    Py_DECREF(self->valname);						\    Py_INCREF(value);							\    self->valname = (PyArrayObject *) value;				\    return 0;								\  } /* ========================================================== DTLSys struct */typedef struct {  PyObject_HEAD  PyArrayObject *wt;  PyArrayObject *bs;  PyArrayObject *xt;} DTLSys; /* ============================================================ Declaration */static void       DTLSys_dealloc(DTLSys* self);static PyObject * DTLSys_new(PyTypeObject *type, PyObject *args, PyObject *kwds);static int        DTLSys_init(DTLSys *self, PyObject *args, PyObject *kwds); static PyMemberDef DTLSys_members[] = {    {NULL}  /* Sentinel */}; static PyObject * DTLSys_get_wt(DTLSys *self,                  void *closure);static int        DTLSys_set_wt(DTLSys *self, PyObject *value, void *closure);static PyObject * DTLSys_get_bs(DTLSys *self,                  void *closure);static int        DTLSys_set_bs(DTLSys *self, PyObject *value, void *closure);static PyObject * DTLSys_get_xt(DTLSys *self,                  void *closure);static int  
      DTLSys_set_xt(DTLSys *self, PyObject *value, void *closure); static PyGetSetDef DTLSys_getseters[] = {    {"wt", (getter)DTLSys_get_wt, (setter)DTLSys_set_wt, "Matrix", NULL},    {"bs", (getter)DTLSys_get_bs, (setter)DTLSys_set_bs, "Vector", NULL},    {"xt", (getter)DTLSys_get_xt, (setter)DTLSys_set_xt, "Vector time sequence", NULL},    {NULL}  /* Sentinel */}; static int       _DTLSys_check_sys_conf(DTLSys *self);static PyObject * DTLSys_check_sys_conf(DTLSys *self);static PyObject * DTLSys_make_tms(DTLSys *self, PyObject *args); static PyMethodDef DTLSys_methods[] = {  {"check_sys_conf", (PyCFunction)DTLSys_check_sys_conf, METH_NOARGS, "Check if system config is correct"},  {"make_tms", (PyCFunction)DTLSys_make_tms, METH_VARARGS, "Make TiMe Series"},  {NULL}  /* Sentinel */}; static PyTypeObject DTLSysType = {    PyObject_HEAD_INIT(NULL)    0,				/*ob_size*/    "dtls.DTLSys",		/*tp_name*/    sizeof(DTLSys),             /*tp_basicsize*/    0,				/*tp_itemsize*/    (destructor)DTLSys_dealloc, /*tp_dealloc*/    0,				/*tp_print*/    0,				/*tp_getattr*/    0,				/*tp_setattr*/    0,				/*tp_compare*/    0,				/*tp_repr*/    0,				/*tp_as_number*/    0,				/*tp_as_sequence*/    0,				/*tp_as_mapping*/    0,				/*tp_hash */    0,				/*tp_call*/    0,				/*tp_str*/    0,				/*tp_getattro*/    0,				/*tp_setattro*/    0,				/*tp_as_buffer*/    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,	/*tp_flags*/    "DTLSys objects",           /* tp_doc */    0,				/* tp_traverse */    0,				/* tp_clear */    0,				/* tp_richcompare */    0,				/* tp_weaklistoffset */    0,				/* tp_iter */    0,				/* tp_iternext */    DTLSys_methods,             /* tp_methods */    DTLSys_members,             /* tp_members */    DTLSys_getseters,           /* tp_getset */    0,				/* tp_base */    0,				/* tp_dict */    0,				/* tp_descr_get */    0,				/* tp_descr_set */    0,				/* tp_dictoffset */    (initproc)DTLSys_init,      /* tp_init */    0,				/* tp_alloc */    DTLSys_new,             
    /* tp_new */}; static PyMethodDef module_methods[] = {    {NULL}  /* Sentinel */}; #ifndef PyMODINIT_FUNC	/* declarations for DLL import/export */#define PyMODINIT_FUNC void#endifPyMODINIT_FUNCinitdtls(void) {  PyObject* m;  if (PyType_Ready(&DTLSysType) < 0){ return; }   m = Py_InitModule3("dtls", module_methods,		     "Example module that creates an extension type.");  if (m == NULL){ return; }   Py_INCREF(&DTLSysType);  PyModule_AddObject(m, "DTLSys", (PyObject *)&DTLSysType);  import_array();   /* required NumPy initialization */} /* ================================================================= Define */static voidDTLSys_dealloc(DTLSys* self){  Py_XDECREF(self->wt);  Py_XDECREF(self->bs);  Py_XDECREF(self->xt);  self->ob_type->tp_free((PyObject*)self);} static PyObject *DTLSys_new(PyTypeObject *type, PyObject *args, PyObject *kwds){  DTLSys *self;  npy_intp wt_dims[2] = {0,0};  npy_intp bs_dims[1] = {0};  npy_intp xt_dims[2] = {0,0};   self = (DTLSys *)type->tp_alloc(type, 0);  if (self != NULL) {    self->wt = (PyArrayObject *) PyArray_SimpleNew(2, wt_dims, NPY_DOUBLE);    if (self->wt == NULL){ Py_DECREF(self); return NULL; }    self->bs = (PyArrayObject *) PyArray_SimpleNew(1, bs_dims, NPY_DOUBLE);    if (self->bs == NULL){ Py_DECREF(self); return NULL; }    self->xt = (PyArrayObject *) PyArray_SimpleNew(2, xt_dims, NPY_DOUBLE);    if (self->xt == NULL){ Py_DECREF(self); return NULL; }  }   return (PyObject *)self;} static intDTLSys_init(DTLSys *self, PyObject *args, PyObject *kwds){  PyArrayObject *wt, *bs, *tmp;  int t_max;  npy_intp xt_dims[2] = {0,0};   if ( !PyArg_ParseTuple( args, "O!O!i:DTLSys.init",			  &PyArray_Type, &wt,			  &PyArray_Type, &bs,			  &t_max)       ) {    return -1; /* PyArg_ParseTuple has raised an exception */  }   if ( wt==NULL || bs==NULL ) {    printf("getting args failed\n"); return -1;  }  if ( t_max < 0 ) {    printf("t_max (3rd arg) must be positive int\n"); return -1;  }   xt_dims[0] = PyArray_DIM(wt,0);  xt_dims[1] 
= t_max;  self->xt = (PyArrayObject *) PyArray_SimpleNew(2, xt_dims, NPY_DOUBLE);  if (self->xt == NULL){     printf("creating %dx%d array failed\n", (int)xt_dims[0], (int)xt_dims[1]);    return -1;  }   tmp = self->wt; Py_INCREF(wt); self->wt = wt; Py_DECREF(tmp);  tmp = self->bs; Py_INCREF(bs); self->bs = bs; Py_DECREF(tmp);   if( _DTLSys_check_sys_conf(self) != 0 ){    PyErr_Clear();    printf("DTLSys config is not correct!\n");  }  return 0;} DEF_PYARRAY_GETTER( DTLSys_get_wt, DTLSys, wt )DEF_PYARRAY_SETTER( DTLSys_set_wt, DTLSys, wt, 2 )DEF_PYARRAY_GETTER( DTLSys_get_bs, DTLSys, bs )DEF_PYARRAY_SETTER( DTLSys_set_bs, DTLSys, bs, 1 )DEF_PYARRAY_GETTER( DTLSys_get_xt, DTLSys, xt )DEF_PYARRAY_SETTER( DTLSys_set_xt, DTLSys, xt, 2 ) static int_DTLSys_check_sys_conf(DTLSys *self){  int vecsize;   NDIM_CHECK(self->wt, 2, -1); TYPE_CHECK(self->wt, NPY_DOUBLE, -1);  NDIM_CHECK(self->bs, 1, -1); TYPE_CHECK(self->bs, NPY_DOUBLE, -1);   vecsize = PyArray_DIM(self->wt,0);  if (vecsize != PyArray_DIM(self->wt,1) ) {    PyErr_Format( PyExc_ValueError, "self.wt must be square");    return -1;  }  if (vecsize != PyArray_DIM(self->bs,0) ) {    PyErr_Format( PyExc_ValueError, "self.bs and self.wt[0] must be same shape");    return -1;  }  if (vecsize != PyArray_DIM(self->xt,0) ) {    PyErr_Format( PyExc_ValueError, "self.xt[,0] and self.wt[0] must be same shape");    return -1;  }  return 0;} static PyObject *DTLSys_check_sys_conf(DTLSys *self){  if( _DTLSys_check_sys_conf(self) != 0 ){    PyErr_Clear();    Py_RETURN_FALSE;  }  Py_RETURN_TRUE;} static PyObject *DTLSys_make_tms(DTLSys *self, PyObject *args){  int vecsize, t_max, i, j, t;  if( _DTLSys_check_sys_conf(self) != 0 ){    return NULL;  }  vecsize = PyArray_DIM(self->wt,0);  t_max   = PyArray_DIM(self->xt,1);   for (t = 1; t < t_max; t++) {    for (i = 0; i < vecsize; i++) {      DIND2(self->xt,i,t) = DIND1(self->bs,i);      for (j = 0; j < vecsize; j++) {	DIND2(self->xt,i,t) += DIND2(self->wt,i,j) * 
DIND2(self->xt,j,t-1);      }    }  }  return Py_BuildValue("");  /* return None */} /*# setup.py# build command : python setup.py build build_ext --inplacefrom numpy.distutils.core import setup, Extensionimport os, numpy name = 'dtls'sources = ['dtlsmodule.c'] include_dirs = [    numpy.get_include()    ] setup( name = name,       include_dirs = include_dirs,       ext_modules = [Extension(name, sources)]       )*/ /*# test codeimport scipy, pylabimport dtls t_max = 200rot = 5.0 * 2.0 * scipy.pi / t_maxwt = scipy.array([    [  scipy.cos(rot), scipy.sin(rot) ],    [ -scipy.sin(rot), scipy.cos(rot) ]    ])wt *= 0.99bs = scipy.array([0.,0.]) a=dtls.DTLSys( wt, bs, t_max )a.xt[0,0] = 0a.xt[1,0] = 1a.make_tms() pylab.clf()pylab.plot(a.xt[0],a.xt[1], 'o-') # Check calculationprint (scipy.dot( a.wt, a.xt[:,0] ) + a.bs) - a.xt[:,1]*/`