def accuracy(y_hat, y): #对于一个分类问题,我们想计算的是它的准确率,他并不是一个回归函数,所以我们需要的是计算它的准确率"""计算预测正确的数量。"""#现在 y_hat是一个256*10的一个矩阵if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1) # 因为使用了softmax,所以我们的得到的是一个概率分布,我们取值最大的,也就是预测的概率值最大的索引,对应真实值的缩影cmp = pe(y.dtype) == y #因为数据类型可能不同,这里只是记录一个boolean值print(cmp)print(cmp.shape)return pe(y.dtype).sum())
比如在softmax这个例子中
y_hat是一个 25610的矩阵,这10代表着,模型预测这个样本可能是哪一类的概率,我们用axis取出一个点,这个点代表着,最有可能的点,然后记录下它的索引,
y_hat现在变成了一个2561的矩阵,对比y(label)如果是一个值返回True,如果不是就返回False
因为担心y_hat和y的数据类型不一样,先做强制类型转换
再最后,加权y_hat
class Accumulator:#在这里创建一个累加器的类"""在`n`个变量上累加。"""def __init__(self, n):self.data = [0.0] * n # 创建一个一行n列的矩阵,或者说向量def add(self, *args):print(args)self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx): #定义这个函数它的实例对象(假设为P)就可以这样P[key]取值return self.data[idx]
非常之简单,这就是一个非常简单的累加器。初始化的时候要记录我们需要几个记录的点。比如训练的时候要记录总损失,总预测对的数量,总预测数量,最后进行一个使用。不用理会add,add就是累加一轮的损失,预测对的数量,总预测数量
def evaluate_accuracy(net, data_iter):# 在这里我们定义一个评估函数"""计算在指定数据集上模型的精度。"""if isinstance(net, Module):net.eval() # 将模型设置为评估模式metric = Accumulator(2) # 正确预测数、预测总数for X, y in data_iter:metric.add(accuracy(net(X), y), y.numel()) #accuracy的第一个参数是X,第二个参数是y,这句话的意思就是,return metric[0] / metric[1]
对于测试集,我们值关心他的准确率
class Animator:"""在动画中绘制数据。"""def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,figsize=(3.5, 2.5)):# 增量地绘制多条线if legend is None:legend = []d2l.use_svg_display()self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)if nrows * ncols == 1:self.axes = [self.axes, ]# 使用lambda函数捕获参数fig_axes = lambda: d2l.set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)self.X, self.Y, self.fmts = None, None, fmtsdef add(self, x, y):# 向图表中添加多个数据点if not hasattr(y, "__len__"):y = [y]n = len(y)if not hasattr(x, "__len__"):x = [x] * nif not self.X:self.X = [[] for _ in range(n)]if not self.Y:self.Y = [[] for _ in range(n)]for i, (a, b) in enumerate(zip(x, y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)self.axes[0].cla()for x, y, fmt in zip(self.X, self.Y, self.fmts):self.axes[0].plot(x, y, fig_axes()display.display(self.fig)display.clear_output(wait=True)
对于这个类,你只需要只要,你传入epho迭代的次数和 (训练损失,准确率),(测试损失,测试准确率)就ok了
本文发布于:2024-01-27 23:59:50,感谢您对本站的认可!
本文链接:https://www.4u4v.net/it/17063712023429.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
留言与评论(共有 0 条评论) |