算子编译期参数检查

2025-01-22  本文已影响0人  Joe_WQ

说明

pytorch的底层实现是用的c++,导致检查type极其麻烦,需要jit、template之类的技术,但是zig是支持type类型的,comptime如虎添翼。

简单的函数名和参数检查

注:以下代码是deepseek r1 生成的,然后改了下bug,在 0.14-dev下能跑起来。

const std = @import("std");

const Tensor = struct {};
// 定义 Tensor 结构体示例
const Dpu = struct {
    // 示例方法,符合要求的签名
    pub fn add(self: *Tensor, dim: i32, index: *Tensor, src: *Tensor) Tensor {
        std.debug.print("add {}, {}, {}, {}\n", .{ self, dim, index, src });
        return Tensor{};
    }
};

// 编译时检查方法签名的函数
fn checkMethodSignature(
    comptime Struct: type,
    comptime methodName: []const u8,
    comptime expectedParamTypes: []const type,
    comptime expectedReturnType: type,
) void {
    // 检查方法是否存在于结构体中
    if (!@hasDecl(Struct, methodName)) {
        @compileError("方法 '" ++ methodName ++ "' 不存在于 " ++ @typeName(Struct));
    }

    // 获取方法实例及其类型信息
    const method = @field(Struct, methodName);
    const FuncType = @TypeOf(method);
    const funcInfo = @typeInfo(FuncType).@"fn";

    // 检查参数数量(包含 self)
    const expectedParamCount = expectedParamTypes.len;
    if (funcInfo.params.len != expectedParamCount) {
        @compileError("参数数量错误,期望 " ++ std.fmt.comptimePrint("{}", .{expectedParamCount}) ++ " 个,实际 " ++ std.fmt.comptimePrint("{}", .{funcInfo.params.len}));
    }

    // 检查参数类型
    for (funcInfo.params[0..], 0..) |param, i| {
        const expectedType = expectedParamTypes[i];
        if (param.type != expectedType) {
            @compileError("参数 " ++ std.fmt.comptimePrint("{}", .{i}) ++ " 类型应为 " ++ @typeName(expectedType) ++ ",实际为 " ++ @typeName(param.type.?));
        }
    }

    // 检查返回类型
    if (funcInfo.return_type != expectedReturnType) {
        @compileError("返回类型应为 " ++ @typeName(expectedReturnType) ++ ",实际为 " ++ @typeName(funcInfo.return_type.?));
    }
}

// 编译时执行检查
comptime {
    checkMethodSignature(Dpu, "add", &[_]type{ *Tensor, i32, *Tensor, *Tensor }, Tensor);
}

pub fn main() void {
    std.debug.print("函数签名检查通过!\n", .{});
}

上一篇 下一篇

猜你喜欢

热点阅读