Prototype
CLASS torch.nn.Hardtanh(min_val=-1.0, max_val=1.0, inplace=False, min_value=None, max_value=None)
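Parameters
min_val: minimum value of the linear region. Default: -1
max_val: maximum value of the linear region. Default: 1
inplace: whether to perform the operation in place. Default: False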
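Definition
\text{HardTanh}(x) =
\begin{cases}
\text{max\_val} & \text{if } x > \text{max\_val} \\
\text{min\_val} & \text{if } x < \text{min\_val} \\
x & \text{otherwise}
\end{cases}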
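The piecewise definition can be traced by hand with a small pure-Python sketch. The helper name hardtanh_scalar is hypothetical and is not part of PyTorch; it only mirrors the three cases for a single number, using the same default bounds as the prototype.

```python
def hardtanh_scalar(x: float, min_val: float = -1.0, max_val: float = 1.0) -> float:
    """Illustrative scalar version of the Hardtanh definition (not a PyTorch API)."""
    if x > max_val:
        # Above the linear region: saturate at max_val.
        return max_val
    if x < min_val:
        # Below the linear region: saturate at min_val.
        return min_val
    # Inside [min_val, max_val]: the value passes through unchanged.
    return x


samples = [-3.0, -1.0, 0.25, 1.0, 2.5]
clipped = [hardtanh_scalar(v) for v in samples]
print(clipped)  # [-1.0, -1.0, 0.25, 1.0, 1.0]
```

Values exactly on a boundary fall into the "otherwise" case, which returns the same number as the boundary itself, so the function is continuous.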
Figure
Code
import torch
import torch.nn as nn

m = nn.Hardtanh(-2, 2)  # linear on [-2, 2], clipped outside that range
input = torch.randn(2)
output = m(input)
print("input: ", input)    # sample run: tensor([2.1926, 0.2211])
print("output: ", output)  # sample run: tensor([2.0000, 0.2211]), the first value is clipped to max_val=2
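Because the run above depends on random input, a fixed tensor makes the clipping easier to verify. The following is a minimal sketch, assuming a PyTorch installation; the input values are chosen for illustration. It also compares the module with torch.clamp and torch.nn.functional.hardtanh, which should give the same result for the same bounds.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Fixed input (illustrative values) so the result is reproducible.
x = torch.tensor([-3.0, -0.5, 0.2211, 2.1926])

m = nn.Hardtanh(min_val=-2.0, max_val=2.0)
module_out = m(x)

# Element-wise clamping to the same bounds, for comparison.
clamp_out = torch.clamp(x, min=-2.0, max=2.0)

# Functional form of the same activation.
functional_out = F.hardtanh(x, min_val=-2.0, max_val=2.0)

print(module_out)                               # tensor([-2.0000, -0.5000,  0.2211,  2.0000])
print(torch.equal(module_out, clamp_out))       # True
print(torch.equal(module_out, functional_out))  # True
```

Values inside [-2, 2] pass through unchanged, while -3.0 and 2.1926 are replaced by the nearest bound.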
Hardtanh — PyTorch 1.13 documentation