python多线程方法详解

作者:python_rser 时间:2023-10-16 02:46:31 

处理多个数据和多文件时,使用for循环的速度非常慢,此时需要用多线程来加速运行进度,常用的模块为multiprocess和joblib,下面对两种包我常用的方法进行说明。

1、模块安装

pip install multiprocessing
pip install joblib

2、以分块计算NDVI为例

首先导入需要的包

import numpy as np
from osgeo import gdal
import time
from multiprocessing import cpu_count
from multiprocessing import Pool
from joblib import Parallel, delayed

定义GdalUtil类,以读取遥感数据

class GdalUtil:
   def __init__(self):
       pass
   @staticmethod
   def read_file(raster_file, read_band=None):
       """读取栅格数据"""
       # 注册栅格驱动
       gdal.AllRegister()
       gdal.SetConfigOption('gdal_FILENAME_IS_UTF8', 'YES')
       # 打开输入图像
       dataset = gdal.Open(raster_file, gdal.GA_ReadOnly)
       if dataset == None:
           print('打开图像{0} 失败.\n', raster_file)
       # 列
       raster_width = dataset.RasterXSize
       # 行
       raster_height = dataset.RasterYSize
       # 读取数据
       if read_band == None:
           data_array = dataset.ReadAsArray(0, 0, raster_width, raster_height)
       else:
           band = dataset.GetRasterBand(read_band)
           data_array = band.ReadAsArray(0, 0, raster_width, raster_height)
       return data_array

@staticmethod
   def read_block_data(dataset, band_num, cols_read, rows_read, start_col=0, start_row=0):
       band = dataset.GetRasterBand(band_num)
       res_data = band.ReadAsArray(start_col, start_row, cols_read, rows_read)
       return res_data

@staticmethod
   def get_raster_band(raster_path):
       # 注册栅格驱动
       gdal.AllRegister()
       gdal.SetConfigOption('gdal_FILENAME_IS_UTF8', 'YES')
       # 打开输入图像
       dataset = gdal.Open(raster_path, gdal.GA_ReadOnly)
       if dataset == None:
           print('打开图像{0} 失败.\n', raster_path)
       raster_band = dataset.RasterCount
       return raster_band

@staticmethod
   def get_file_size(raster_path):
       """获取栅格仿射变换参数"""
       # 注册栅格驱动
       gdal.AllRegister()
       gdal.SetConfigOption('gdal_FILENAME_IS_UTF8', 'YES')

# 打开输入图像
       dataset = gdal.Open(raster_path, gdal.GA_ReadOnly)
       if dataset == None:
           print('打开图像{0} 失败.\n', raster_path)
       # 列
       raster_width = dataset.RasterXSize
       # 行
       raster_height = dataset.RasterYSize
       return raster_width, raster_height

@staticmethod
   def get_file_geotransform(raster_path):
       """获取栅格仿射变换参数"""
       # 注册栅格驱动
       gdal.AllRegister()
       gdal.SetConfigOption('gdal_FILENAME_IS_UTF8', 'YES')

# 打开输入图像
       dataset = gdal.Open(raster_path, gdal.GA_ReadOnly)
       if dataset == None:
           print('打开图像{0} 失败.\n', raster_path)

# 获取输入图像仿射变换参数
       input_geotransform = dataset.GetGeoTransform()
       return input_geotransform

@staticmethod
   def get_file_proj(raster_path):
       """获取栅格图像空间参考"""
       # 注册栅格驱动
       gdal.AllRegister()
       gdal.SetConfigOption('gdal_FILENAME_IS_UTF8', 'YES')

# 打开输入图像
       dataset = gdal.Open(raster_path, gdal.GA_ReadOnly)
       if dataset == None:
           print('打开图像{0} 失败.\n', raster_path)

# 获取输入图像空间参考
       input_project = dataset.GetProjection()
       return input_project

@staticmethod
   def write_file(dataset, geotransform, project, output_path, out_format='GTiff', eType=gdal.GDT_Float32):
       """写入栅格"""
       if np.ndim(dataset) == 3:
           out_band, out_rows, out_cols = dataset.shape
       else:
           out_band = 1
           out_rows, out_cols = dataset.shape

# 创建指定输出格式的驱动
       out_driver = gdal.GetDriverByName(out_format)
       if out_driver == None:
           print('格式%s 不支持Creat()方法.\n', out_format)
           return

out_dataset = out_driver.Create(output_path, xsize=out_cols,
                                       ysize=out_rows, bands=out_band,
                                       eType=eType)
       # 设置输出图像的仿射参数
       out_dataset.SetGeoTransform(geotransform)

# 设置输出图像的投影参数
       out_dataset.SetProjection(project)

# 写出数据
       if out_band == 1:
           out_dataset.GetRasterBand(1).WriteArray(dataset)
       else:
           for i in range(out_band):
               out_dataset.GetRasterBand(i + 1).WriteArray(dataset[i])
       del out_dataset

定义计算NDVI的函数

def cal_ndvi(multi):
   '''
   计算高分NDVI
   :param multi:格式为列表,依次包含[遥感文件路径,开始行号,开始列号,待读的行数,待读的列数]
   :return: NDVI数组
   '''
   input_file, start_col, start_row, cols_step, rows_step = multi
   dataset = gdal.Open(input_file, gdal.GA_ReadOnly)
   nir_data = GdalUtil.read_block_data(dataset, 4, cols_step, rows_step, start_col=start_col, start_row=start_row)
   red_data = GdalUtil.read_block_data(dataset, 3, cols_step, rows_step, start_col=start_col, start_row=start_row)
   ndvi = (nir_data - red_data) / (nir_data + red_data)
   ndvi[(ndvi > 1.5) | (ndvi < -1)] = 0
   return ndvi

定义主函数

if __name__ == "__main__":
   input_file = r'D:\originalData\GF1\namucuo2021.tif'
   output_file = r'D:\originalData\GF1\namucuo2021_ndvi.tif'
   method = 'joblib'
   # method = 'multiprocessing'
   # 获取文件主要信息
   raster_cols, raster_rows = GdalUtil.get_file_size(input_file)
   geotransform = GdalUtil.get_file_geotransform(input_file)
   project = GdalUtil.get_file_proj(input_file)
   # 定义分块大小
   rows_block_size = 50
   cols_block_size = 50
   multi = []
   for j in range(0, raster_rows, rows_block_size):
       for i in range(0, raster_cols, cols_block_size):
           if j + rows_block_size < raster_rows:
               rows_step = rows_block_size
           else:
               rows_step = raster_rows - j
           # 数据横向步长
           if i + cols_block_size < raster_cols:
               cols_step = cols_block_size
           else:
               cols_step = raster_cols - i
           temp_multi = [input_file, i, j, cols_step, rows_step]
           multi.append(temp_multi)

t1 = time.time()
   if method == 'multiprocessing':
       # multiprocessing方法
       pool = Pool(processes=cpu_count()-1)
       # 注意map函数中传入的参数应该是可迭代对象,如list;返回值为list
       res = pool.map(cal_ndvi, multi)
       pool.close()
       pool.join()
   else:
       # joblib方法
       res = Parallel(n_jobs=-1)(delayed(cal_ndvi)(input_list) for input_list in multi)

t2 = time.time()
   print("Total time:" + (t2 - t1).__str__())

# 将multiprocessing中的结果提取出来,放回对应的矩阵位置中
   out_data = np.zeros([raster_rows, raster_cols], dtype='float')
   for result, input_multi in zip(res, multi):
       start_col = input_multi[1]
       start_row = input_multi[2]
       cols_step = input_multi[3]
       rows_step = input_multi[4]
       out_data[start_row:start_row + rows_step, start_col:start_col + cols_step] = result

GdalUtil.write_file(out_data, geotransform, project, output_file)

双重for循环时,两层for循环都使用multiprocessing时会报错,这时可以外层for循环使用joblib方法,内层for循环改为multiprocessing方法,不会报错

来源:https://blog.csdn.net/u010562884/article/details/122543534

标签:python,多线程
0
投稿

猜你喜欢

  • 避坑:Sql中 in 和not in中有null值的情况说明

    2024-01-22 17:01:33
  • SQLServer 2000 数据库同步详细步骤[两台服务器]

    2024-01-21 11:18:03
  • oracle数据库在客户端建立dblink语法

    2023-07-14 19:51:23
  • python 删除excel表格重复行,数据预处理操作

    2023-04-19 06:10:33
  • 3个 Python 编程技巧

    2023-11-30 08:05:19
  • JavaScript 与 ActionScript 3.0 交互的一些问题

    2008-01-27 12:20:00
  • python爬虫爬取淘宝商品比价(附淘宝反爬虫机制解决小办法)

    2021-11-14 06:16:40
  • Django中的CBV和FBV示例介绍

    2022-05-23 10:13:59
  • Python OpenCV超详细讲解透视变换的实现

    2021-08-02 21:19:48
  • 利用Python批量导出mysql数据库表结构的操作实例

    2024-01-21 00:41:58
  • Python实现批量自动整理文件

    2023-05-10 21:58:12
  • ASP Google的translate API代码

    2011-04-03 11:16:00
  • 基于Node.js实现nodemailer邮件发送

    2024-05-03 15:36:40
  • python爬取微信公众号文章图片并转为PDF

    2021-02-02 06:53:31
  • 端午节将至,用Python爬取粽子数据并可视化,看看网友喜欢哪种粽子吧!

    2023-08-23 06:29:31
  • 详解让Python性能起飞的15个技巧

    2023-10-13 18:12:27
  • DWCS3-CSS布局之一CSS规则大纲

    2008-06-11 18:48:00
  • Python基础教程之tcp socket编程详解及简单实例

    2021-04-18 12:04:29
  • 学习Python需要哪些工具

    2023-06-20 17:48:40
  • 一行两列背景自适应的简单写法 DIV+CSS

    2008-07-15 12:51:00
  • asp之家 网络编程 m.aspxhome.com