pytorch中关于复数的处理

torch中的数据类型与numpy有一点区别,就是他没有complex64和complex128这样的数据类型,就是说在numpy中那些复数的简单表达(a+b*j)在torch中是无法使用的。

然而我们常常要使用fft在频域中处理复数变量的。

torch中对复数的处理跟c非常类似(一维数组变成二维数组)。其实也非常简单,解决的方法就是设计把矩阵中的数据一维度化,然后再加上一个维度,新的维度用来存储复数域,然后我们再对其做各种操作。

举个例子:一个shape=[n,n]的矩阵,我们先reshape成[n2n^2n2]的一个维度的东西,然后用unsqueeze变成[n2n^2n2,1]的一个东西,我们将之作为一个数组的实部(real)。然后我们建立一个相同大小的数组作为虚部(img)。然后使用torch.cat((real,img),1)将他们拼接起来。

接下来,我们设计一个斜坡滤波器(RL滤波器)来实现高通滤波,这其中对于复数的处理是非常关键的。

我把主要的东西写了一下,细节的东西还没有完善

#复数乘法
def complexMulti(a,b):
 r = a.shape[0]
 c = torch.zeros([r,2])
 for i in range(r):
  c[i,0] = a[i,0]*b[i,0]-a[i,1]*b[i,1]
  c[i,1] = a[i,0]*b[i,1]+a[i,1]*b[i,0]
 return c
#1.RL滤波器
def RLfilter(pic):
 r,c = pic.shape
 f = torch.zeros(c)
 n = c//2
 if c%2 == 0:
  for i in range(n):
   f[i] = i
   f[n+i] = n-i   
 else:
  f[n] = n
  for i in range(1,n+1):
   f[n-i] = n-i
   f[n+i] = n-i
 return f/n
#滤波处理
def filterDotTensor(pic,fil):
 r,c = pic.shape
 fpic = torch.zeros([r,c])
 newpic = torch.zeros([r,c])
 com = torch.ones(c)
 com = torch.unsqueeze(com,1).to(torch.float64)
 fil = torch.unsqueeze(fil,1).to(torch.float64)
 fil = torch.cat((fil,com),1)
 for k in range(180):
  newLine = torch.unsqueeze(pic[k,:],1)
  newLine = torch.cat((newLine,com),1)
  newLine = torch.fft(newLine,1)
  fnew = torch.squeeze(newLine,1)
  fnew = complexMulti(fil,fnew)#复数乘法
  fnew = torch.ifft(fnew,1)
  newpic[k,:] = fnew[:,0]#只取实部
 return newpic
Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐