网站首页 文章专栏 python matplotlib函数备忘
python matplotlib函数备忘
创建于:2018-03-23 16:00:00 更新于:2025-01-18 10:05:47 羽瀚尘 1849
python python

1 简介

matplotlib是python中一个非常好用的画图库,倾向于使用数据画图,设计思路与matlab中的plot相同。

1.1 画图与保存

1.1.1 无显示器画图

ssh远程操作 出现RuntimeError: Invalid DISPLAY variable 添加如下代码

plt.switch_backend('agg')

参考页面

1.1.2 保存图像

使用plt.savefig

import matplotlib.pyplot as plt
plt.savefig("filename.png")
plt.show()

注意savefig必须在show之前调用,否则show之后默认开新图,保存的图一片空白

或者,使用gcf方法

fig = plt.gcf()
plt.show()
fig1.savefig('test.jpg', dpi=100)

1.1.3 图像格式

在plt.savefig()方法中增加format=参数 可选的参数如下: - jpg - png - pdf - eps - svg

完整的调用方法为 plt.savefig('file_name', format='jpg') 如果不指定format,默认为jpg格式,与文件的后缀名无关

1.1.4 设置图像dpi

plt.savefig(..., dpi=150)

1.1.5 直接获取bin图像流

在服务器环境,或者特定环境下,我们不建议用文件来交换画图结果,更希望函数直接返回一个二进制的图像文件。

import matplotlib.pyplot as plt
import io
from PIL import Image 
# ... plot something
canvas = plt.get_current_fig_manager().canvas
canvas.draw()
buf, size = canvas.print_to_buffer()
image = Image.frombuffer('RGBA', size, buf, 'raw', 'RGBA', 0, 1)
buffer = io.BytesIO()
image.save(buffer,'PNG')
data = buffer.getvalue()

1.2 画图增强

1.2.1 画多个子图

共享x y轴的意思是,多张图是否使用同一个单位刻度,共享后只会在最左边的y轴和最下边的 x轴标出数字,其他轴只有单位刻度。

import numpy as np 
import matplotlib.pyplot as plt 

x = np.linspace(0, 2*np.pi, 400)
y = np.sin(x**2)

# 不共享y轴
f, (ax1, ax2) = plt.subplots(1, 2, sharey=False)
ax1.plot(x, y)
ax1.set_title('Not sharing Y axis')
ax2.scatter(x, y)

# 共享y轴
f, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
ax1.plot(x, y)
ax1.set_title('Sharing Y axis')
ax2.scatter(x, y)


plt.show()

Creates four polar axes, and accesses them through the returned array

>>> fig, axes = plt.subplots(2, 2, subplot_kw=dict(polar=True))
>>> axes[0, 0].plot(x, y)
>>> axes[1, 1].scatter(x, y)

1.2.2 增加子图的title

ax.set_title('Simple plot')

参考

1.2.3 画grid

 plt.grid(True, color='grey', linestyle='-', linewidth=1)
 ```

 ### 1.2.4 关闭坐标刻度
 ```py
 plt.xticks([])
 plt.yticks([])
 ```

 ### 1.2.5 关闭坐标轴
 整个坐标系统都不见了,只剩下曲线
 ```py
 plt.axis('off')
 ```



 ### 1.2.6 坐标轴不可见
 ```py
 frame = plt.gca()
 frame.axes.get_yaxis().set_visible(False)
 frame.axes.get_xaxis().set_visible(False)
 ```
### 1.2.7 画图例

```py
plt.figure()
plt.plot(data['loss100'])
plt.plot(data['loss300'])
plt.plot(data['loss600'])

plt.title("Compare loss(different lr)") 
plt.xlabel("step")
plt.ylabel("loss")
plt.legend(['lr=0.01', 'lr=0.3', 'lr=0.6'])

画图例参考(这个是官方教程,但是无法画出图例,姑且列上等以后研究)

2. 其他常用常数

2.1 格式化输出

  1. 使用百分号
print('hello %s%.4f' %('str', 5.0))
  1. 使用format
print('hello {:.4f}/{:.5f} {}'.format(5,6,'str'))

2.2 zip函数

zip() 函数用于将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表。

如果各个迭代器的元素个数不一致,则返回列表长度与最短的对象相同,利用 * 号操作符,可以将元组解压为列表。

>>>a = [1,2,3]
>>> b = [4,5,6]
>>> c = [4,5,6,7,8]
>>> zipped = zip(a,b)     # 打包为元组的列表
[(1, 4), (2, 5), (3, 6)]
>>> zip(a,c)              # 元素个数与最短的列表一致
[(1, 4), (2, 5), (3, 6)]
>>> zip(*zipped)          # 与 zip 相反,可理解为解压,返回二维矩阵式
[(1, 2, 3), (4, 5, 6)]