跳转到主要内容

Documentation Index

Fetch the complete documentation index at: https://dripart-mintlify-e28287af.mintlify.app/llms.txt

Use this file to discover all available pages before exploring further.

pytorch、张量与 torch.Tensor

Comfy 的所有核心数值计算都是由 pytorch 完成的。如果你的自定义节点需要深入 stable diffusion 的底层,你就需要熟悉这个库,这远超本简介的范围。 不过,许多自定义节点都需要操作图像、潜变量和蒙版,这些在内部都表示为 torch.Tensor,因此你可能需要收藏 torch.Tensor 的官方文档

什么是张量?

torch.Tensor 表示张量,张量是向量或矩阵在任意维度上的数学泛化。张量的 (rank)是它的维度数量(所以向量秩为 1,矩阵秩为 2);它的 形状(shape)描述了每个维度的大小。 因此,一个 RGB 图像(高为 H,宽为 W)可以被看作是三组数组(每个颜色通道一组),每组大小为 H x W,可以表示为形状为 [H,W,3] 的张量。在 Comfy 中,图像几乎总是以批量(batch)形式出现(即使批量中只有一张图)。torch 总是将批量维放在第一位,所以 Comfy 的图像形状为 [B,H,W,3],通常写作 [B,H,W,C],其中 C 代表通道数(Channels)。

squeeze、unsqueeze 与 reshape

如果张量的某个维度大小为 1(称为折叠维度),那么去掉这个维度后的张量与原张量等价(比如只有一张图片的批量其实就是一张图片)。去除这种折叠维度称为 squeeze,插入一个这样的维度称为 unsqueeze。
有些 torch 代码和自定义节点作者会在某个维度折叠时返回 squeeze 过的张量——比如批量只有一个成员时。这是常见的 bug 来源!
将同样的数据以不同的形状表示称为 reshape。通常你需要了解底层数据结构,因此请谨慎操作!

重要符号说明

torch.Tensor 支持大多数 Python 的切片符号、迭代和其他常见的类列表操作。张量还有一个 .shape 属性,返回其大小,类型为 torch.Size(它是 tuple 的子类,可以当作元组使用)。 还有一些你经常会见到的重要符号(其中几个在标准 Python 里不常见,但在处理张量时很常用):
  • torch.Tensor 支持在切片符号中使用 None,表示插入一个大小为 1 的新维度。
  • : 在切片张量时常用,表示”保留整个维度”。就像 Python 里的 a[start:end],但省略了起止点。
  • ... 表示”未指定数量的所有维度”。所以 a[0, ...] 会提取批量中的第一个元素,无论有多少维度。
  • 在需要传递形状的函数中,形状通常以 tuple 形式传递,其中某个维度可以用 -1,表示该维度的大小由数据总量自动推算。
>>> a = torch.Tensor((1,2))
>>> a.shape
torch.Size([2])
>>> a[:,None].shape 
torch.Size([2, 1])
>>> a.reshape((1,-1)).shape
torch.Size([1, 2])

元素级操作

许多 torch.Tensor 的二元操作(包括 ’+’, ’-’, ’*’, ’/’ 和 ’==‘)都是元素级的(即对每个元素独立操作)。操作数必须是形状相同的两个张量,或一个张量和一个标量。所以:
>>> import torch
>>> a = torch.Tensor((1,2))
>>> b = torch.Tensor((3,2))
>>> a*b
tensor([3., 4.])
>>> a/b
tensor([0.3333, 1.0000])
>>> a==b
tensor([False,  True])
>>> a==1
tensor([ True, False])
>>> c = torch.Tensor((3,2,1)) 
>>> a==c
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0

张量的布尔值

张量的”真值”与 Python 列表的真值不同。
你可能熟悉 Python 列表的真值:非空列表为 TrueNone[]False。而 torch.Tensor(只要有多个元素)没有定义的真值。你需要用 .all().any() 来合并元素级的真值:
>>> a = torch.Tensor((1,2))
>>> print("yes" if a else "no")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
>>> a.all()
tensor(False)
>>> a.any()
tensor(True)
这也意味着你需要用 if a is not None: 而不是 if a: 来判断一个张量变量是否已被赋值。