Skip to content
Snippets Groups Projects
Commit 5308303056ee authored by Kurt Smith's avatar Kurt Smith
Browse files

cleanup in MemoryViewSliceType

parent 3da81d70f961
Branches
No related tags found
No related merge requests found
...@@ -302,7 +302,42 @@ ...@@ -302,7 +302,42 @@
} }
''' '''
def get_copy_contents_code(from_mvs, to_mvs, cfunc_name): def memoryviewslice_get_copy_func(from_memview, to_memview, mode, scope):
from PyrexTypes import CFuncType, CFuncTypeArg
if mode == 'c':
cython_name = "copy"
copy_name = '__Pyx_BufferNew_C_From_'+from_memview.specialization_suffix()
contig_flag = 'PyBUF_C_CONTIGUOUS'
elif mode == 'fortran':
cython_name = "copy_fortran"
copy_name = "__Pyx_BufferNew_F_From_"+from_memview.specialization_suffix()
contig_flag = 'PyBUF_F_CONTIGUOUS'
else:
assert False
copy_contents_name = get_copy_contents_name(from_memview, to_memview)
scope.declare_cfunction(cython_name,
CFuncType(from_memview,
[CFuncTypeArg("memviewslice", from_memview, None)]),
pos = None,
defining = 1,
cname = copy_name)
copy_impl = copy_template % dict(
copy_name=copy_name,
mode=mode,
sizeof_dtype="sizeof(%s)" % from_memview.dtype.declaration_code(''),
contig_flag=contig_flag,
copy_contents_name=copy_contents_name)
copy_decl = ("static __Pyx_memviewslice "
"%s(const __Pyx_memviewslice); /* proto */\n" % (copy_name,))
return (copy_decl, copy_impl)
def get_copy_contents_func(from_mvs, to_mvs, cfunc_name):
assert from_mvs.dtype == to_mvs.dtype assert from_mvs.dtype == to_mvs.dtype
assert len(from_mvs.axes) == len(to_mvs.axes) assert len(from_mvs.axes) == len(to_mvs.axes)
...@@ -313,7 +348,10 @@ ...@@ -313,7 +348,10 @@
if access != 'direct': if access != 'direct':
raise NotImplementedError("only direct access supported currently.") raise NotImplementedError("only direct access supported currently.")
code = ''' code_decl = ("static int %(cfunc_name)s(const __Pyx_memviewslice *from_mvs,"
"__Pyx_memviewslice *to_mvs); /* proto */" % {'cfunc_name' : cfunc_name})
code_impl = '''
static int %(cfunc_name)s(const __Pyx_memviewslice *from_mvs, __Pyx_memviewslice *to_mvs) { static int %(cfunc_name)s(const __Pyx_memviewslice *from_mvs, __Pyx_memviewslice *to_mvs) {
...@@ -338,7 +376,7 @@ ...@@ -338,7 +376,7 @@
# 'i' always goes up from zero to ndim-1. # 'i' always goes up from zero to ndim-1.
# 'idx' is the same as 'i' for c_contig, and goes from ndim-1 to 0 for f_contig. # 'idx' is the same as 'i' for c_contig, and goes from ndim-1 to 0 for f_contig.
# this makes the loop code below identical in both cases. # this makes the loop code below identical in both cases.
code += INDENT+"Py_ssize_t i%d = 0, idx%d = 0;\n" % (i,i) code_impl += INDENT+"Py_ssize_t i%d = 0, idx%d = 0;\n" % (i,i)
code += INDENT+"Py_ssize_t stride%(i)d = from_mvs->diminfo[%(idx)d].strides;\n" % {'i':i, 'idx':idx} code_impl += INDENT+"Py_ssize_t stride%(i)d = from_mvs->diminfo[%(idx)d].strides;\n" % {'i':i, 'idx':idx}
code += INDENT+"Py_ssize_t shape%(i)d = from_mvs->diminfo[%(idx)d].shape;\n" % {'i':i, 'idx':idx} code_impl += INDENT+"Py_ssize_t shape%(i)d = from_mvs->diminfo[%(idx)d].shape;\n" % {'i':i, 'idx':idx}
...@@ -344,6 +382,6 @@ ...@@ -344,6 +382,6 @@
code += "\n" code_impl += "\n"
# put down the nested for-loop. # put down the nested for-loop.
for k in range(ndim): for k in range(ndim):
...@@ -346,6 +384,6 @@ ...@@ -346,6 +384,6 @@
# put down the nested for-loop. # put down the nested for-loop.
for k in range(ndim): for k in range(ndim):
code += INDENT*(k+1) + "for(i%(k)d=0; i%(k)d<shape%(k)d; i%(k)d++) {\n" % {'k' : k} code_impl += INDENT*(k+1) + "for(i%(k)d=0; i%(k)d<shape%(k)d; i%(k)d++) {\n" % {'k' : k}
if k >= 1: if k >= 1:
...@@ -351,3 +389,3 @@ ...@@ -351,3 +389,3 @@
if k >= 1: if k >= 1:
code += INDENT*(k+2) + "idx%(k)d = i%(k)d * stride%(k)d + idx%(km1)d;\n" % {'k' : k, 'km1' : k-1} code_impl += INDENT*(k+2) + "idx%(k)d = i%(k)d * stride%(k)d + idx%(km1)d;\n" % {'k' : k, 'km1' : k-1}
else: else:
...@@ -353,6 +391,6 @@ ...@@ -353,6 +391,6 @@
else: else:
code += INDENT*(k+2) + "idx%(k)d = i%(k)d * stride%(k)d;\n" % {'k' : k} code_impl += INDENT*(k+2) + "idx%(k)d = i%(k)d * stride%(k)d;\n" % {'k' : k}
# the inner part of the loop. # the inner part of the loop.
dtype_decl = from_mvs.dtype.declaration_code("") dtype_decl = from_mvs.dtype.declaration_code("")
last_idx = ndim-1 last_idx = ndim-1
...@@ -355,9 +393,9 @@ ...@@ -355,9 +393,9 @@
# the inner part of the loop. # the inner part of the loop.
dtype_decl = from_mvs.dtype.declaration_code("") dtype_decl = from_mvs.dtype.declaration_code("")
last_idx = ndim-1 last_idx = ndim-1
code += INDENT*ndim+"memcpy(to_buf, from_buf+idx%(last_idx)d, sizeof(%(dtype_decl)s));\n" % locals() code_impl += INDENT*ndim+"memcpy(to_buf, from_buf+idx%(last_idx)d, sizeof(%(dtype_decl)s));\n" % locals()
code += INDENT*ndim+"to_buf += sizeof(%(dtype_decl)s);\n" % locals() code_impl += INDENT*ndim+"to_buf += sizeof(%(dtype_decl)s);\n" % locals()
# for-loop closing braces # for-loop closing braces
for k in range(ndim-1, -1, -1): for k in range(ndim-1, -1, -1):
...@@ -361,6 +399,6 @@ ...@@ -361,6 +399,6 @@
# for-loop closing braces # for-loop closing braces
for k in range(ndim-1, -1, -1): for k in range(ndim-1, -1, -1):
code += INDENT*(k+1)+"}\n" code_impl += INDENT*(k+1)+"}\n"
# init to_mvs->data and to_mvs->diminfo. # init to_mvs->data and to_mvs->diminfo.
...@@ -365,9 +403,9 @@ ...@@ -365,9 +403,9 @@
# init to_mvs->data and to_mvs->diminfo. # init to_mvs->data and to_mvs->diminfo.
code += INDENT+"temp_memview = to_mvs->memview;\n" code_impl += INDENT+"temp_memview = to_mvs->memview;\n"
code += INDENT+"temp_data = to_mvs->data;\n" code_impl += INDENT+"temp_data = to_mvs->data;\n"
code += INDENT+"to_mvs->memview = 0; to_mvs->data = 0;\n" code_impl += INDENT+"to_mvs->memview = 0; to_mvs->data = 0;\n"
code += INDENT+"if(unlikely(-1 == __Pyx_init_memviewslice(temp_memview, %d, to_mvs))) {\n" % (ndim,) code_impl += INDENT+"if(unlikely(-1 == __Pyx_init_memviewslice(temp_memview, %d, to_mvs))) {\n" % (ndim,)
code += INDENT*2+"return -1;\n" code_impl += INDENT*2+"return -1;\n"
code += INDENT+"}\n" code_impl += INDENT+"}\n"
...@@ -373,3 +411,3 @@ ...@@ -373,3 +411,3 @@
code += INDENT + "return 0;\n" code_impl += INDENT + "return 0;\n"
...@@ -375,3 +413,3 @@ ...@@ -375,3 +413,3 @@
code += '}\n' code_impl += '}\n'
...@@ -377,5 +415,5 @@ ...@@ -377,5 +415,5 @@
return code return code_decl, code_impl
def get_axes_specs(env, axes): def get_axes_specs(env, axes):
''' '''
......
...@@ -317,12 +317,8 @@ ...@@ -317,12 +317,8 @@
to_memview_c = MemoryViewSliceType(self.dtype, to_axes_c, self.env) to_memview_c = MemoryViewSliceType(self.dtype, to_axes_c, self.env)
to_memview_f = MemoryViewSliceType(self.dtype, to_axes_f, self.env) to_memview_f = MemoryViewSliceType(self.dtype, to_axes_f, self.env)
cython_name_c = 'copy' copy_contents_name_c =\
cython_name_f = 'copy_fortran' MemoryView.get_copy_contents_name(self, to_memview_c)
copy_contents_name_f =\
copy_name_c = '__Pyx_BufferNew_C_From_'+self.specialization_suffix() MemoryView.get_copy_contents_name(self, to_memview_f)
copy_name_f = '__Pyx_BufferNew_F_From_'+self.specialization_suffix()
c_copy_util_code = UtilityCode()
f_copy_util_code = UtilityCode()
...@@ -328,14 +324,6 @@ ...@@ -328,14 +324,6 @@
for (to_memview, copy_name, cython_name, mode, contig_flag, util_code) in ( c_copy_decl, c_copy_impl = \
(to_memview_c, copy_name_c, cython_name_c, 'c', 'PyBUF_C_CONTIGUOUS', c_copy_util_code), MemoryView.memoryviewslice_get_copy_func(self, to_memview_c, 'c', self.scope)
(to_memview_f, copy_name_f, cython_name_f, 'fortran', 'PyBUF_F_CONTIGUOUS', f_copy_util_code)): f_copy_decl, f_copy_impl = \
MemoryView.memoryviewslice_get_copy_func(self, to_memview_f, 'fortran', self.scope)
copy_contents_name = MemoryView.get_copy_contents_name(self, to_memview)
scope.declare_cfunction(cython_name,
CFuncType(self,
[CFuncTypeArg("memviewslice", self, None)]),
pos = None,
defining = 1,
cname = copy_name)
...@@ -341,12 +329,8 @@ ...@@ -341,12 +329,8 @@
copy_impl = MemoryView.copy_template %\ c_copy_contents_decl, c_copy_contents_impl = \
dict(copy_name=copy_name, MemoryView.get_copy_contents_func(
mode=mode, self, to_memview_c, copy_contents_name_c)
sizeof_dtype="sizeof(%s)" % self.dtype.declaration_code(''), f_copy_contents_decl, f_copy_contents_impl = \
contig_flag=contig_flag, MemoryView.get_copy_contents_func(
copy_contents_name=copy_contents_name) self, to_memview_f, copy_contents_name_f)
copy_decl = '''\
static __Pyx_memviewslice %s(const __Pyx_memviewslice); /* proto */
''' % (copy_name,)
...@@ -352,16 +336,9 @@ ...@@ -352,16 +336,9 @@
util_code.proto = copy_decl c_util_code = UtilityCode(
util_code.impl = copy_impl proto = "%s%s" % (c_copy_decl, c_copy_contents_decl),
impl = "%s%s" % (c_copy_impl, c_copy_contents_impl))
copy_contents_name_c = MemoryView.get_copy_contents_name(self, to_memview_c) f_util_code = UtilityCode(
copy_contents_name_f = MemoryView.get_copy_contents_name(self, to_memview_f) proto = f_copy_decl,
impl = f_copy_impl)
c_copy_util_code.proto += ('static int %s'
'(const __Pyx_memviewslice *,'
' __Pyx_memviewslice *); /* proto */\n' %
(copy_contents_name_c,))
c_copy_util_code.impl += \
MemoryView.get_copy_contents_code(self, to_memview_c, copy_contents_name_c)
if copy_contents_name_c != copy_contents_name_f: if copy_contents_name_c != copy_contents_name_f:
...@@ -366,8 +343,5 @@ ...@@ -366,8 +343,5 @@
if copy_contents_name_c != copy_contents_name_f: if copy_contents_name_c != copy_contents_name_f:
f_util_code.proto += f_copy_contents_decl
f_copy_util_code.proto += ('static int %s' f_util_code.impl += f_copy_contents_impl
'(const __Pyx_memviewslice *,'
' __Pyx_memviewslice *); /* proto */\n' %
(copy_contents_name_f,))
...@@ -373,8 +347,5 @@ ...@@ -373,8 +347,5 @@
f_copy_util_code.impl += \ c_copy_used = [1 for util_code in self.env.global_scope().utility_code_list if c_util_code.proto == util_code.proto]
MemoryView.get_copy_contents_code(self, to_memview_f, copy_contents_name_f) f_copy_used = [1 for util_code in self.env.global_scope().utility_code_list if f_util_code.proto == util_code.proto]
c_copy_used = [1 for util_code in self.env.global_scope().utility_code_list if c_copy_util_code.proto == util_code.proto]
f_copy_used = [1 for util_code in self.env.global_scope().utility_code_list if f_copy_util_code.proto == util_code.proto]
if not c_copy_used: if not c_copy_used:
...@@ -379,5 +350,5 @@ ...@@ -379,5 +350,5 @@
if not c_copy_used: if not c_copy_used:
self.env.use_utility_code(c_copy_util_code) self.env.use_utility_code(c_util_code)
if not f_copy_used: if not f_copy_used:
...@@ -382,6 +353,6 @@ ...@@ -382,6 +353,6 @@
if not f_copy_used: if not f_copy_used:
self.env.use_utility_code(f_copy_util_code) self.env.use_utility_code(f_util_code)
# is_c_contiguous and is_f_contiguous functions # is_c_contiguous and is_f_contiguous functions
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment