# blob: 888d70f305eff5fd33a95cb71ffa6e9a2ac38ff4 [file] [edit]
# Copyright (C) Microsoft Corporation. All rights reserved.
# This file is distributed under the University of Illinois Open Source License. See LICENSE.TXT for details.
import argparse
import functools
import collections
from hctdb import *
# get db singletons
# Lazily-built db_dxil singleton; access it through get_db_dxil().
g_db_dxil = None


def get_db_dxil():
    """Return the shared db_dxil instance, constructing it on the first call."""
    global g_db_dxil
    if g_db_dxil is not None:
        return g_db_dxil
    g_db_dxil = db_dxil()
    return g_db_dxil
# Lazily-built db_hlsl singleton; access it through get_db_hlsl().
g_db_hlsl = None


def get_db_hlsl():
    """Return the shared db_hlsl instance, constructing it on the first call.

    The database is parsed from ``gen_intrin_main.txt`` in the same directory
    as this script.
    """
    global g_db_hlsl
    # Imported here because this module never imports os itself; it used to
    # rely on the name leaking through ``from hctdb import *``.
    import os

    if g_db_hlsl is None:
        thisdir = os.path.dirname(os.path.realpath(__file__))
        with open(os.path.join(thisdir, "gen_intrin_main.txt"), "r") as f:
            g_db_hlsl = db_hlsl(f)
    return g_db_hlsl
def format_comment(prefix, val):
    """Format *val* as word-wrapped line comments, each starting with *prefix*.

    Lines are wrapped at spaces so that prefix plus text stays under 80
    columns where possible. A single word longer than the available width is
    emitted on its own (overlong) line instead of being split.

    Returns the comment text with a trailing newline on every line, or an
    empty string when *val* is empty.
    """
    result = ""
    line_width = 80
    content_width = line_width - len(prefix)
    while val:
        if len(val) < content_width:
            result += prefix + val.strip() + "\n"
            break
        split_idx = val.rfind(" ", 0, content_width)
        if split_idx < 0:
            # No space fits within the width. Break after the overlong word;
            # using rfind's -1 directly would leave val unchanged and loop
            # forever.
            split_idx = val.find(" ", content_width)
            if split_idx < 0:
                result += prefix + val.strip() + "\n"
                break
        result += prefix + val[:split_idx].strip() + "\n"
        val = val[split_idx + 1 :]
    return result
def format_rst_table(list_of_tuples):
    """Produce a reStructuredText simple table from a sequence of row tuples.

    The first row is the header and is followed by a banner line; the table
    is also opened and closed by a banner. Cells are converted with ``str``
    and left-aligned to the widest cell of their column (column count is
    taken from the first row). Trailing whitespace is trimmed on each line.

    Returns an empty string when there are no rows.
    """
    # Materialize once so a one-shot iterator survives both passes below.
    rows = list(list_of_tuples)
    if not rows:
        return ""
    # Calculate widths.
    widths = [0] * len(rows[0])
    for row in rows:
        for i, v in enumerate(row):
            widths[i] = max(widths[i], len(str(v)))
    # Build banner line.
    banner = " ".join("=" * w for w in widths) + "\n"
    # Build the result.
    lines = [banner]
    for idx, row in enumerate(rows):
        cells = []
        for j, v in enumerate(row):
            text = str(v)
            cells.append(text + " " * (widths[j] - len(text)))
        # Trim only this line. Trimming the accumulated result would also
        # remove the previous line's newline when a row is entirely blank.
        lines.append(" ".join(cells).rstrip() + "\n")
        if idx == 0:
            lines.append(banner)
    lines.append(banner)
    return "".join(lines)
def build_range_tuples(i):
    """Yield inclusive (low, high) pairs for runs of consecutive values.

    The input is sorted first. Each maximal run in which every value is
    exactly one more than its predecessor is reported as one tuple.
    """
    run = None  # [low, high] of the run currently being extended
    for value in sorted(i):
        if run is not None and value == run[1] + 1:
            run[1] = value
            continue
        if run is not None:
            yield (run[0], run[1])
        run = [value, value]
    if run is not None:
        yield (run[0], run[1])
def build_range_code(var, i):
    """Return a C-style boolean expression testing whether *var* is in *i*.

    Consecutive values are collapsed into ``(low <= var && var <= high)``
    range checks; isolated values become ``var == value``. The pieces are
    joined with ``||``. An empty *i* yields an empty string.
    """
    conditions = []
    for low, high in build_range_tuples(i):
        if low == high:
            conditions.append(var + " == " + str(low))
        else:
            conditions.append("(%d <= %s && %s <= %d)" % (low, var, var, high))
    return " || ".join(conditions)
class db_docsref_gen:
    "A generator of reference documentation."

    def __init__(self, db):
        """Capture *db* and pre-sort its DXIL instructions and validation rules.

        Both lists are ordered by "<category>.<name>", with a None category
        sorting as the empty string.
        """
        self.db = db
        sort_key = lambda v: ("" if v.category is None else v.category) + "." + v.name
        self.instrs = sorted([i for i in self.db.instr if i.is_dxil_op], key=sort_key)
        self.val_rules = sorted(db.val_rules, key=sort_key)

    def print_content(self):
        """Print the complete HTML reference page to stdout."""
        self.print_header()
        self.print_body()
        self.print_footer()

    def print_header(self):
        """Print the HTML preamble and the instruction/rule tables of contents."""
        print("<!DOCTYPE html>")
        print("<html><head><title>DXIL Reference</title>")
        print("<style>body { font-family: Verdana; font-size: small; }</style>")
        print("</head><body><h1>DXIL Reference</h1>")
        self.print_toc("Instructions", "i", self.instrs)
        self.print_toc("Rules", "r", self.val_rules)

    def print_body(self):
        """Print the per-instruction and per-rule detail sections."""
        self.print_instruction_details()
        self.print_valrule_details()

    def print_instruction_details(self):
        """Print one section per DXIL instruction, listing its operands."""
        print("<h2>Instruction Details</h2>")
        for i in self.instrs:
            print("<h3><a name='i%s'>%s</a></h3>" % (i.name, i.name))
            print("<div>Opcode: %d. This instruction %s.</div>" % (i.dxil_opid, i.doc))
            if i.remarks:
                # This is likely a .rst fragment, but this will do for now.
                print("<div> " + i.remarks + "</div>")
            print("<div>Operands:</div>")
            print("<ul>")
            for o in i.ops:
                if o.pos == 0:
                    # Operand 0 is the result.
                    print("<li>result: %s - %s</li>" % (o.llvm_type, o.doc))
                else:
                    # Enum-typed operands also list their permitted value names.
                    # This must go through self.db; a bare `db` is not defined
                    # in this scope and raised NameError.
                    enum_desc = (
                        ""
                        if o.enum_name == ""
                        else " one of %s: %s"
                        % (
                            o.enum_name,
                            ",".join(self.db.enum_idx[o.enum_name].value_names()),
                        )
                    )
                    print(
                        "<li>%d - %s: %s%s%s</li>"
                        % (
                            o.pos - 1,
                            o.name,
                            o.llvm_type,
                            "" if o.doc == "" else " - " + o.doc,
                            enum_desc,
                        )
                    )
            print("</ul>")
            print("<div><a href='#Instructions'>(top)</a></div>")

    def print_valrule_details(self):
        """Print one section per validation rule with its description."""
        print("<h2>Rule Details</h2>")
        for i in self.val_rules:
            print("<h3><a name='r%s'>%s</a></h3>" % (i.name, i.name))
            print("<div>" + i.doc + "</div>")
            print("<div><a href='#Rules'>(top)</a></div>")

    def print_toc(self, name, aprefix, values):
        """Print a table of contents for *values*, grouped by category.

        Each category gets a heading and its own <ul>; entries link to the
        anchor *aprefix* + item name.
        """
        print("<h2><a name='" + name + "'>" + name + "</a></h2>")
        # Track whether a <ul> is open rather than seeding last_category with
        # a value, so no stray "</ul>" is printed before the first list.
        list_open = False
        last_category = None
        for i in values:
            if not list_open or i.category != last_category:
                if list_open:
                    print("</ul>")
                print("<div><b>%s</b></div><ul>" % i.category)
                list_open = True
                last_category = i.category
            print("<li><a href='#" + aprefix + "%s'>%s</a></li>" % (i.name, i.name))
        if list_open:
            print("</ul>")

    def print_footer(self):
        """Print the closing HTML tags."""
        print("</body></html>")
class db_instrhelp_gen:
    "A generator of instruction helper classes."

    def __init__(self, db):
        self.db = db
        # Operand type codes that can be exposed as typed C++ values, mapped
        # to the C++ type name and its width in bits.
        TypeInfo = collections.namedtuple("TypeInfo", "name bits")
        self.llvm_type_map = {
            "i1": TypeInfo("bool", 1),
            "i8": TypeInfo("int8_t", 8),
            "u8": TypeInfo("uint8_t", 8),
            "i32": TypeInfo("int32_t", 32),
            "u32": TypeInfo("uint32_t", 32),
        }
        # C++ function the generated operator bool() calls for DXIL ops.
        self.IsDxilOpFuncCallInst = "hlsl::OP::IsDxilOpFuncCallInst"

    def print_content(self):
        """Print the whole generated header: banner, helper structs, footer."""
        self.print_header()
        self.print_body()
        self.print_footer()

    def print_header(self):
        """Print the file banner, TODO notes and the opening of namespace hlsl."""
        print(
            "///////////////////////////////////////////////////////////////////////////////"
        )
        print(
            "// //"
        )
        print(
            "// Copyright (C) Microsoft Corporation. All rights reserved. //"
        )
        print(
            "// DxilInstructions.h //"
        )
        print(
            "// //"
        )
        print(
            "// This file provides a library of instruction helper classes. //"
        )
        print(
            "// //"
        )
        print(
            "// MUCH WORK YET TO BE DONE - EXPECT THIS WILL CHANGE - GENERATED FILE //"
        )
        print(
            "// //"
        )
        print(
            "///////////////////////////////////////////////////////////////////////////////"
        )
        print("")
        print("// TODO: add correct include directives")
        print("// TODO: add accessors with values")
        print("// TODO: add validation support code, including calling into right fn")
        print("// TODO: add type hierarchy")
        print("namespace hlsl {")

    def bool_lit(self, val):
        """Return the C++ literal "true" or "false" for the truthiness of *val*."""
        return "true" if val else "false"

    def op_type(self, o):
        """Return the C++ type name for operand *o*.

        Raises ValueError if o.llvm_type is not in llvm_type_map.
        """
        if o.llvm_type in self.llvm_type_map:
            return self.llvm_type_map[o.llvm_type].name
        raise ValueError(
            "Don't know how to describe type %s for operand %s." % (o.llvm_type, o.name)
        )

    def op_size(self, o):
        """Return the bit width for operand *o*.

        Raises ValueError if o.llvm_type is not in llvm_type_map.
        """
        if o.llvm_type in self.llvm_type_map:
            return self.llvm_type_map[o.llvm_type].bits
        raise ValueError(
            "Don't know how to describe type %s for operand %s." % (o.llvm_type, o.name)
        )

    def op_const_expr(self, o):
        """Return a C++ expression reading operand *o* as a constant of its C++ type.

        o.pos - 1 is the call-operand index, since pos 0 is the return value.
        """
        return (
            "(%s)(llvm::dyn_cast<llvm::ConstantInt>(Instr->getOperand(%d))->getZExtValue())"
            % (self.op_type(o), o.pos - 1)
        )

    def op_set_const_expr(self, o):
        """Return a C++ expression building an integer constant from ``val``.

        The constant's width is the operand's bit width from llvm_type_map.
        """
        type_size = self.op_size(o)
        return (
            "llvm::Constant::getIntegerValue(llvm::IntegerType::get(Instr->getContext(), %d), llvm::APInt(%d, (uint64_t)val))"
            % (type_size, type_size)
        )

    def print_body(self):
        """Print one helper struct per non-reserved instruction in db.instr.

        Every struct wraps an llvm::Instruction* and provides:
        - operator bool() to identify the instruction;
        - isAllowed().

        DXIL ops additionally get:
        - isArgumentListValid() and requiresUniformInputs();
        - an OperandIdx enum and get_/set_ accessors for operands after the
          opcode;
        - typed *_val accessors for constant operands whose type is in
          llvm_type_map.
        """
        for i in self.db.instr:
            if i.is_reserved:
                continue
            # Struct name: explicit helper prefix, else DxilInst_/LlvmInst_.
            if i.inst_helper_prefix:
                struct_name = "%s_%s" % (i.inst_helper_prefix, i.name)
            elif i.is_dxil_op:
                struct_name = "DxilInst_%s" % i.name
            else:
                struct_name = "LlvmInst_%s" % i.name
            if i.doc:
                print("/// This instruction %s" % i.doc)
            print("struct %s {" % struct_name)
            print(" llvm::Instruction *Instr;")
            print(" // Construction and identification")
            print(" %s(llvm::Instruction *pInstr) : Instr(pInstr) {}" % struct_name)
            print(" operator bool() const {")
            if i.is_dxil_op:
                op_name = i.fully_qualified_name()
                print(
                    " return %s(Instr, %s);" % (self.IsDxilOpFuncCallInst, op_name)
                )
            else:
                print(
                    " return Instr->getOpcode() == llvm::Instruction::%s;" % i.name
                )
            print(" }")
            print(" // Validation support")
            print(
                " bool isAllowed() const { return %s; }" % self.bool_lit(i.is_allowed)
            )
            if i.is_dxil_op:
                # Expected argument count is every op except the return value.
                print(" bool isArgumentListValid() const {")
                print(
                    " if (%d != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;"
                    % (len(i.ops) - 1)
                )
                print(" return true;")
                # TODO - check operand types
                print(" }")
                print(" // Metadata")
                print(
                    " bool requiresUniformInputs() const { return %s; }"
                    % self.bool_lit(i.requires_uniform_inputs)
                )
                # The OperandIdx enum is opened lazily, only if some operand
                # follows the opcode.
                EnumWritten = False
                for o in i.ops:
                    if o.pos > 1:  # 0 is return type, 1 is DXIL OP id
                        if not EnumWritten:
                            print(" // Operand indexes")
                            print(" enum OperandIdx {")
                            EnumWritten = True
                        print(" arg_%s = %d," % (o.name, o.pos - 1))
                if EnumWritten:
                    print(" };")
                AccessorsWritten = False
                for o in i.ops:
                    if o.pos > 1:  # 0 is return type, 1 is DXIL OP id
                        if not AccessorsWritten:
                            print(" // Accessors")
                            AccessorsWritten = True
                        print(
                            " llvm::Value *get_%s() const { return Instr->getOperand(%d); }"
                            % (o.name, o.pos - 1)
                        )
                        print(
                            " void set_%s(llvm::Value *val) { Instr->setOperand(%d, val); }"
                            % (o.name, o.pos - 1)
                        )
                        # Typed value accessors exist only for constant operands
                        # whose type has a known C++ mapping.
                        if o.is_const:
                            if o.llvm_type in self.llvm_type_map:
                                print(
                                    " %s get_%s_val() const { return %s; }"
                                    % (self.op_type(o), o.name, self.op_const_expr(o))
                                )
                                print(
                                    " void set_%s_val(%s val) { Instr->setOperand(%d, %s); }"
                                    % (
                                        o.name,
                                        self.op_type(o),
                                        o.pos - 1,
                                        self.op_set_const_expr(o),
                                    )
                                )
            print("};")
            print("")

    def print_footer(self):
        """Print the closing brace of namespace hlsl."""
        print("} // namespace hlsl")
class db_enumhelp_gen:
    "A generator of enumeration declarations."

    def __init__(self, db):
        self.db = db
        # Enumerations that get a trailing "one past the last value" marker,
        # mapped to that marker's name.
        self.lastEnumNames = {"OpCode": "NumOpCodes", "OpCodeClass": "NumOpClasses"}

    def print_enum(self, e, **kwargs):
        """Print enumeration *e* as a C++ ``enum class <name> : unsigned``.

        Keyword arguments:
            hide_val: when true, omit the ``= value`` initializers.
            sort_val: when true (the default), order values by
                "<category>.<name>"; otherwise keep ``e.values`` order.

        A ``// <category>`` comment is printed whenever the category changes.
        Enumerations listed in ``lastEnumNames`` also get per-DXIL-version
        markers taken from ``db.dxil_version_info``, followed by a final
        marker equal to the number of values.
        """
        print("// %s" % e.doc)
        print("enum class %s : unsigned {" % e.name)
        hide_val = kwargs.get("hide_val", False)
        if kwargs.get("sort_val", True):
            ordered = sorted(
                e.values,
                key=lambda v: ("" if v.category is None else v.category) + "." + v.name,
            )
        else:
            ordered = e.values
        current_category = None
        for v in ordered:
            if v.category != current_category:
                if current_category is not None:
                    print("")
                print(" // %s" % v.category)
                current_category = v.category
            pieces = [" {name}"]
            if not (e.is_internal or hide_val):
                pieces.append(" = {value}")
            pieces.append(",")
            if v.doc:
                pieces.append(" // {doc}")
            print("".join(pieces).format(name=v.name, value=v.value, doc=v.doc))
        marker = self.lastEnumNames.get(e.name)
        if marker is not None:
            version_lines = []
            for (major, minor), info in sorted(self.db.dxil_version_info.items()):
                if marker in info:
                    version_lines.append(
                        "%s_Dxil_%d_%d = %d," % (marker, major, minor, info[marker])
                    )
            if version_lines:
                print("")
                for text in version_lines:
                    print(" " + text)
            print("")
            print(
                " "
                + marker
                + " = "
                + str(len(ordered))
                + " // exclusive last value of enumeration"
            )
        print("};")

    def print_rdat_enum(self, e, **kwargs):
        """Print one RDAT_ENUM_VALUE line per value of *e*.

        With ``nodef=True``, RDAT_ENUM_VALUE_NODEF(name) is used instead of
        RDAT_ENUM_VALUE(value, name).
        """
        if kwargs.get("nodef", False):
            template = "RDAT_ENUM_VALUE_NODEF({name})"
        else:
            template = "RDAT_ENUM_VALUE({value}, {name})"
        for v in e.values:
            line = template + (" // {doc}" if v.doc else "")
            print(line.format(name=v.name, value=v.value, doc=v.doc))

    def print_content(self):
        """Print every enumeration in the database, ordered by name."""
        for enum in sorted(self.db.enums, key=lambda item: item.name):
            self.print_enum(enum)
class db_oload_gen:
    "A generator of overload tables."

    def __init__(self, db):
        self.db = db
        instrs = [i for i in self.db.instr if i.is_dxil_op]
        # DXIL operations in opcode order.
        self.instrs = sorted(instrs, key=lambda i: i.dxil_opid)
        # Allow these to be overridden by external scripts.
        self.OP = "OP"
        self.OC = "OC"
        self.OCC = "OCC"

    def print_content(self):
        """Print the opcode property table, a "..." separator, then the GetOpFunc cases."""
        self.print_opfunc_props()
        print("...")
        self.print_opfunc_table()

    def print_opfunc_props(self):
        """Print the ``OP::m_OpCodeProps`` table, one row per DXIL operation.

        Each row holds the opcode, its name, its opcode class and class name,
        one boolean per overload type code, and the function attribute.
        """
        print(
            "const {OP}::OpCodeProperty {OP}::m_OpCodeProps[(unsigned){OP}::OpCode::NumOpCodes] = {{".format(
                OP=self.OP
            )
        )
        print(
            "// OpCode OpCode name, OpCodeClass OpCodeClass name, void, h, f, d, i1, i8, i16, i32, i64, udt, obj, function attribute"
        )
        # Example formatted string:
        # { OC::TempRegLoad, "TempRegLoad", OCC::TempRegLoad, "tempRegLoad", false, true, true, false, true, false, true, true, false, Attribute::ReadOnly, },
        last_category = None
        # overload types are a string of (v)oid, (h)alf, (f)loat, (d)ouble, (1)-bit, (8)-bit, (w)ord, (i)nt, (l)ong, u(dt)
        f = lambda i, c: "true" if i.oload_types.find(c) >= 0 else "false"
        # Class names whose lower-camel-case form is not just a lowered first letter.
        lower_exceptions = {
            "CBufferLoad": "cbufferLoad",
            "CBufferLoadLegacy": "cbufferLoadLegacy",
            "GSInstanceID": "gsInstanceID",
        }
        lower_fn = (
            lambda t: lower_exceptions[t]
            if t in lower_exceptions
            else t[:1].lower() + t[1:]
        )
        attr_dict = {
            "": "None",
            "ro": "ReadOnly",
            "rn": "ReadNone",
            "amo": "ArgMemOnly",
            "nd": "NoDuplicate",
            "nr": "NoReturn",
            "wv": "None",
        }
        attr_fn = lambda i: "Attribute::" + attr_dict[i.fn_attr] + ","
        for i in self.instrs:
            if last_category != i.category:
                if last_category is not None:
                    print("")
                print(
                    " // {category:118} void, h, f, d, i1, i8, i16, i32, i64, udt, obj , function attribute".format(
                        category=i.category
                    )
                )
                last_category = i.category
            print(
                " {{ {OC}::{name:24} {quotName:27} {OCC}::{className:25} {classNameQuot:28} {{{v:>6},{h:>6},{f:>6},{d:>6},{b:>6},{e:>6},{w:>6},{i:>6},{l:>6},{u:>6},{o:>6}}}, {attr:20} }},".format(
                    name=i.name + ",",
                    quotName='"' + i.name + '",',
                    className=i.dxil_class + ",",
                    classNameQuot='"' + lower_fn(i.dxil_class) + '",',
                    v=f(i, "v"),
                    h=f(i, "h"),
                    f=f(i, "f"),
                    d=f(i, "d"),
                    b=f(i, "1"),
                    e=f(i, "8"),
                    w=f(i, "w"),
                    i=f(i, "i"),
                    l=f(i, "l"),
                    u=f(i, "u"),
                    o=f(i, "o"),
                    attr=attr_fn(i),
                    OC=self.OC,
                    OCC=self.OCC,
                )
            )
        print("};")

    def print_opfunc_table(self):
        """Print the ``case OpCode::X:`` lines for OP::GetOpFunc.

        Each case lists one type macro per operand (return value first),
        looked up from op_type_texts. Asserts that every operand type is known.
        """
        # Print the table for OP::GetOpFunc
        op_type_texts = {
            "$cb": "CBRT(pETy);",
            "$o": "A(pETy);",
            "$r": "RRT(pETy);",
            "d": "A(pF64);",
            "dims": "A(pDim);",
            "f": "A(pF32);",
            "h": "A(pF16);",
            "i1": "A(pI1);",
            "i16": "A(pI16);",
            "i32": "A(pI32);",
            "i32c": "A(pI32C);",
            "i64": "A(pI64);",
            "i8": "A(pI8);",
            "$u4": "A(pI4S);",
            "pf32": "A(pPF32);",
            "res": "A(pRes);",
            "splitdouble": "A(pSDT);",
            "twoi32": "A(p2I32);",
            "twof32": "A(p2F32);",
            "twof16": "A(p2F16);",
            "twoi16": "A(p2I16);",
            "threei32": "A(p3I32);",
            "threef32": "A(p3F32);",
            "fouri32": "A(p4I32);",
            "fourf32": "A(p4F32);",
            "fouri16": "A(p4I16);",
            "fourf16": "A(p4F16);",
            "u32": "A(pI32);",
            "u64": "A(pI64);",
            "u8": "A(pI8);",
            "v": "A(pV);",
            "$vec4": "VEC4(pETy);",
            "w": "A(pWav);",
            "SamplePos": "A(pPos);",
            "udt": "A(udt);",
            "obj": "A(obj);",
            "resproperty": "A(resProperty);",
            "resbind": "A(resBind);",
            "waveMat": "A(pWaveMatPtr);",
            "waveMatProps": "A(pWaveMatProps);",
            "$gsptr": "A(pGSEltPtrTy);",
            "nodehandle": "A(pNodeHandle);",
            "noderecordhandle": "A(pNodeRecordHandle);",
            "nodeproperty": "A(nodeProperty);",
            "noderecordproperty": "A(nodeRecordProperty);",
        }
        last_category = None
        for i in self.instrs:
            if last_category != i.category:
                if last_category is not None:
                    print("")
                print(" // %s" % i.category)
                last_category = i.category
            line = " case OpCode::{name:24}".format(name=i.name + ":")
            for index, o in enumerate(i.ops):
                assert (
                    o.llvm_type in op_type_texts
                ), "llvm type %s in instruction %s is unknown" % (o.llvm_type, i.name)
                op_type_text = op_type_texts[o.llvm_type]
                # The return type column is wider than the parameter columns.
                if index == 0:
                    line = line + "{val:13}".format(val=op_type_text)
                else:
                    line = line + "{val:9}".format(val=op_type_text)
            line = line + "break;"
            print(line)

    def print_opfunc_oload_type(self):
        """Print the case labels forming the body of OP::GetOverloadType.

        Operations whose return type is the overload type are skipped. The
        rest are grouped and printed in this order:
        - by parameter index, when the overload type (or a udt/obj type) is a
          parameter;
        - by pointer parameter index, when it is the pointee of a $gsptr;
        - by fixed type, when only one overload exists;
        - struct-returning operations, whose overload is element 0 of the
          returned struct.
        """
        # Print the function for OP::GetOverloadType
        elt_ty = "$o"
        res_ret_ty = "$r"
        cb_ret_ty = "$cb"
        udt_ty = "udt"
        obj_ty = "obj"
        vec_ty = "$vec"
        gsptr_ty = "$gsptr"
        # C++ expression producing the LLVM type for each single-overload code.
        type_code_texts = {
            "d": "Type::getDoubleTy(Ctx)",
            "f": "Type::getFloatTy(Ctx)",
            # Must be a call like the others; the bare "Type::getHalfTy"
            # generated a function name where a Type* was required.
            "h": "Type::getHalfTy(Ctx)",
            "1": "IntegerType::get(Ctx, 1)",
            "8": "IntegerType::get(Ctx, 8)",
            "w": "IntegerType::get(Ctx, 16)",
            "i": "IntegerType::get(Ctx, 32)",
            "l": "IntegerType::get(Ctx, 64)",
            "v": "Type::getVoidTy(Ctx)",
            "u": "Type::getInt32PtrTy(Ctx)",
            "o": "Type::getInt32PtrTy(Ctx)",
        }
        index_dict = collections.OrderedDict()
        ptr_index_dict = collections.OrderedDict()
        single_dict = collections.OrderedDict()
        struct_list = []
        for instr in self.instrs:
            ret_ty = instr.ops[0].llvm_type
            # Skip case return type is overload type
            if ret_ty == elt_ty:
                continue
            if ret_ty == res_ret_ty:
                struct_list.append(instr.name)
                continue
            if ret_ty == cb_ret_ty:
                struct_list.append(instr.name)
                continue
            if ret_ty.startswith(vec_ty):
                struct_list.append(instr.name)
                continue
            in_param_ty = False
            # Try to find elt_ty in parameter types.
            for index, op in enumerate(instr.ops):
                # Skip return type.
                if op.pos == 0:
                    continue
                # Skip dxil opcode.
                if op.pos == 1:
                    continue
                op_type = op.llvm_type
                if op_type == elt_ty:
                    # Skip return op
                    index = index - 1
                    if index not in index_dict:
                        index_dict[index] = [instr.name]
                    else:
                        index_dict[index].append(instr.name)
                    in_param_ty = True
                    break
                if op_type == gsptr_ty:
                    # Skip return op
                    index = index - 1
                    if index not in ptr_index_dict:
                        ptr_index_dict[index] = [instr.name]
                    else:
                        ptr_index_dict[index].append(instr.name)
                    in_param_ty = True
                    break
                if op_type == udt_ty or op_type == obj_ty:
                    # Skip return op
                    index = index - 1
                    if index not in index_dict:
                        index_dict[index] = [instr.name]
                    else:
                        index_dict[index].append(instr.name)
                    in_param_ty = True
                    # Stop at the first match like the other branches; keeping
                    # on would file the opcode under several indices and
                    # generate duplicate case labels.
                    break
            if in_param_ty:
                continue
            # No overload, just return the single oload_type.
            assert len(instr.oload_types) == 1, "overload no elt_ty %s" % (instr.name)
            ty = instr.oload_types[0]
            assert ty in type_code_texts, "llvm type %s is unknown" % (ty)
            ty_code = type_code_texts[ty]
            if ty_code not in single_dict:
                single_dict[ty_code] = [instr.name]
            else:
                single_dict[ty_code].append(instr.name)
        for index, opcodes in index_dict.items():
            line = ""
            for opcode in opcodes:
                line = line + "case OpCode::{name}".format(name=opcode + ":\n")
            line = (
                line
                + " if (FT->getNumParams() <= "
                + str(index)
                + ") return nullptr;\n"
            )
            line = line + " return FT->getParamType(" + str(index) + ");"
            print(line)
        # ptr_index_dict for overload based on pointer element type
        for index, opcodes in ptr_index_dict.items():
            line = ""
            for opcode in opcodes:
                line = line + "case OpCode::{name}".format(name=opcode + ":\n")
            line = (
                line
                + " if (FT->getNumParams() <= "
                + str(index)
                + ") return nullptr;\n"
            )
            line = (
                line
                + " return FT->getParamType("
                + str(index)
                + ")->getPointerElementType();"
            )
            print(line)
        for code, opcodes in single_dict.items():
            line = ""
            for opcode in opcodes:
                line = line + "case OpCode::{name}".format(name=opcode + ":\n")
            line = line + " return " + code + ";"
            print(line)
        line = ""
        for opcode in struct_list:
            line = line + "case OpCode::{name}".format(name=opcode + ":\n")
        line = line + "{\n"
        line = line + " StructType *ST = cast<StructType>(Ty);\n"
        line = line + " return ST->getElementType(0);\n"
        line = line + "}"
        print(line)
class db_valfns_gen:
    "A generator of validation functions."

    def __init__(self, db):
        self.db = db

    def print_content(self):
        """Print the banner, then the generated validation functions."""
        for step in (self.print_header, self.print_body):
            step()

    def print_header(self):
        """Print the section banner for the validation functions."""
        banner_lines = (
            "///////////////////////////////////////////////////////////////////////////////",
            "// Instruction validation functions. //",
        )
        for text in banner_lines:
            print(text)

    def bool_lit(self, val):
        """Return the C++ literal "true" or "false" for the truthiness of *val*."""
        if val:
            return "true"
        return "false"

    def op_type(self, o):
        """Return the C++ type for an i8/u8 operand; raise ValueError otherwise."""
        cpp_names = {"i8": "int8_t", "u8": "uint8_t"}
        if o.llvm_type in cpp_names:
            return cpp_names[o.llvm_type]
        raise ValueError(
            "Don't know how to describe type %s for operand %s." % (o.llvm_type, o.name)
        )

    def op_const_expr(self, o):
        """Return a C++ expression reading i8/u8 operand *o* as a constant.

        Raises ValueError for any other operand type.
        """
        if o.llvm_type not in ("i8", "u8"):
            raise ValueError(
                "Don't know how to describe type %s for operand %s." % (o.llvm_type, o.name)
            )
        template = "(%s)(llvm::dyn_cast<llvm::ConstantInt>(Instr->getOperand(%d))->getZExtValue())"
        return template % (self.op_type(o), o.pos - 1)

    def print_body(self):
        """Print IsLLVMInstructionAllowed, which accepts the allowed LLVM opcodes.

        The allowed set is every instruction that is allowed and is not a DXIL op.
        """
        allowed = [x for x in self.db.instr if x.is_allowed and not x.is_dxil_op]
        print("static bool IsLLVMInstructionAllowed(llvm::Instruction &I) {")
        listing = ", ".join([x.name + "=" + str(x.llvm_id) for x in allowed])
        self.print_comment(" // ", "Allow: %s" % listing)
        print(" unsigned op = I.getOpcode();")
        opcode_ids = [x.llvm_id for x in allowed]
        print(" return %s;" % build_range_code("op", opcode_ids))
        print("}")
        print("")

    def print_comment(self, prefix, val):
        """Print *val* word-wrapped as comment lines starting with *prefix*."""
        text = format_comment(prefix, val)
        print(text)
class macro_table_gen:
    "A generator for macro tables."

    def format_row(self, row, widths, sep=", "):
        """Join the cells of *row*, padding every cell except the last.

        Each non-final cell is followed by *sep* and then padded with spaces
        up to its width from *widths*. The last cell is appended unpadded.
        Cells go through ``str``, so non-string values (e.g. ints) are
        accepted.
        """
        cells = [
            str(item) + sep + (" " * (width - len(str(item))))
            for item, width in list(zip(row, widths))[:-1]
        ]
        cells.append(str(row[-1]))
        return "".join(cells)

    def format_table(self, table, *args, **kwargs):
        """Format every row of *table* using column widths computed over all rows.

        A column's width is the length of its longest cell as ``str`` (at
        least 1). Extra arguments are forwarded to format_row.
        """
        widths = [
            functools.reduce(max, [len(str(row[i])) for row in table], 1)
            for i in range(len(table[0]))
        ]
        return [self.format_row(row, widths, *args, **kwargs) for row in table]

    def print_table(self, table, macro_name):
        """Print *table* as a commented header line plus a ``#define`` X-macro.

        Row 0 becomes the comment. Each remaining row becomes a ``ROW(...)``
        invocation, joined by line continuations.
        """
        formatted = self.format_table(table)
        print(
            "// %s\n" % formatted[0]
            + "#define %s(ROW) \\\n" % macro_name
            + " \\\n".join([" ROW(%s)" % frow for frow in formatted[1:]])
        )
class db_sigpoint_gen(macro_table_gen):
    "A generator for SigPoint tables."

    def __init__(self, db):
        self.db = db

    def print_sigpoint_table(self):
        """Print db.sigpoint_table as the DO_SIGPOINTS X-macro."""
        table = self.db.sigpoint_table
        self.print_table(table, "DO_SIGPOINTS")

    def print_interpretation_table(self):
        """Print db.interpretation_table as the DO_INTERPRETATION_TABLE X-macro."""
        table = self.db.interpretation_table
        self.print_table(table, "DO_INTERPRETATION_TABLE")

    def print_content(self):
        """Print both SigPoint macro tables, sigpoints first."""
        for emit in (self.print_sigpoint_table, self.print_interpretation_table):
            emit()
class string_output:
    """A minimal text sink that collects everything written to it.

    Used as a stand-in for ``sys.stdout`` to capture generator output.
    Chunks are buffered in a list and joined on demand, which avoids the
    quadratic cost of repeated string concatenation on an attribute.
    """

    def __init__(self):
        self._parts = []

    @property
    def val(self):
        """All text written so far (kept as an attribute for compatibility)."""
        return "".join(self._parts)

    @val.setter
    def val(self, text):
        self._parts = [str(text)]

    def write(self, text):
        """Append ``str(text)`` and return the number of characters written."""
        text = str(text)
        self._parts.append(text)
        return len(text)

    def flush(self):
        """No-op. Present so code that flushes stdout works while it is redirected."""

    def __str__(self):
        return self.val
def run_with_stdout(fn):
    """Call *fn* with ``sys.stdout`` redirected and return everything it printed.

    The previous ``sys.stdout`` is restored even if *fn* raises.
    """
    import sys

    captured = string_output()
    previous = sys.stdout
    sys.stdout = captured
    try:
        fn()
    finally:
        sys.stdout = previous
    return str(captured)
def get_hlsl_intrinsic_stats():
    """Return C++ constants describing the size of the HLSL intrinsic tables.

    Emits one ``g_u<Namespace>Count`` constant per namespace, followed by the
    maximum intrinsic name length, parameter name length and parameter count.
    Each maximum is clamped to a minimum (see the NOTE below).
    """
    db = get_db_hlsl()
    longest_fn = db.intrinsics[0]
    longest_param = None
    longest_arglist_fn = db.intrinsics[0]
    for i in sorted(db.intrinsics, key=lambda x: x.key):
        # Get some values for maximum lengths.
        if len(i.name) > len(longest_fn.name):
            longest_fn = i
        for p_idx, p in enumerate(i.params):
            # params[0] is the return entry (the count below subtracts it too).
            if p_idx > 0 and (
                longest_param is None or len(p.name) > len(longest_param.name)
            ):
                longest_param = p
        if len(i.params) > len(longest_arglist_fn.params):
            longest_arglist_fn = i
    result = ""
    for k in sorted(db.namespaces.keys()):
        v = db.namespaces[k]
        result += "static const UINT g_u%sCount = %d;\n" % (k, len(v.intrinsics))
    result += "\n"
    # NOTE: The min limits are needed to support intrinsics in the extension
    # mechanism that use longer values than the builtin HLSL intrinsics.
    # TODO: remove code which depends on g_MaxIntrinsic*.
    MIN_FUNCTION_NAME_LENTH = 44
    MIN_PARAM_NAME_LENTH = 48
    MIN_PARAM_COUNT = 29
    max_fn_name = longest_fn.name
    max_fn_name_len = len(longest_fn.name)
    # If no intrinsic takes a real parameter, longest_param stays None. Fall
    # back to an empty name so the minimum applies instead of raising
    # AttributeError.
    max_param_name = "" if longest_param is None else longest_param.name
    max_param_name_len = len(max_param_name)
    max_param_count_name = longest_arglist_fn.name
    max_param_count = len(longest_arglist_fn.params) - 1
    if max_fn_name_len < MIN_FUNCTION_NAME_LENTH:
        max_fn_name_len = MIN_FUNCTION_NAME_LENTH
        max_fn_name = "MIN_FUNCTION_NAME_LENTH"
    if max_param_name_len < MIN_PARAM_NAME_LENTH:
        max_param_name_len = MIN_PARAM_NAME_LENTH
        max_param_name = "MIN_PARAM_NAME_LENTH"
    if max_param_count < MIN_PARAM_COUNT:
        max_param_count = MIN_PARAM_COUNT
        max_param_count_name = "MIN_PARAM_COUNT"
    result += (
        "static const int g_MaxIntrinsicName = %d; // Count of characters for longest intrinsic name - '%s'\n"
        % (max_fn_name_len, max_fn_name)
    )
    result += (
        "static const int g_MaxIntrinsicParamName = %d; // Count of characters for longest intrinsic parameter name - '%s'\n"
        % (max_param_name_len, max_param_name)
    )
    result += (
        "static const int g_MaxIntrinsicParamCount = %d; // Count of parameters (without return) for longest intrinsic argument list - '%s'\n"
        % (max_param_count, max_param_count_name)
    )
    return result
def get_hlsl_intrinsics():
    """Return the C++ tables describing every HLSL intrinsic, grouped by namespace.

    Intrinsics are visited in ``key`` order. For each one, this emits a
    ``static const HLSL_INTRINSIC_ARGUMENT g_<ns>_Args<n>[]`` array of its
    parameters and appends an entry pointing at that array to the
    ``static const HLSL_INTRINSIC g_<ns>[]`` table of its namespace. A
    namespace table is closed when the namespace changes, and once more at
    the end. A table whose first intrinsic is Vulkan-specific is wrapped in
    ``#ifdef ENABLE_SPIRV_CODEGEN`` / ``#endif``.
    """
    db = get_db_hlsl()
    result = ""  # accumulated output: argument arrays and finished tables
    last_ns = ""  # namespace of the previously visited intrinsic
    ns_table = ""  # table for the current namespace, not yet closed
    is_vk_table = False  # SPIRV Change
    id_prefix = ""  # opcode prefix for the current namespace: "IOP" or "MOP"
    arg_idx = 0  # per-namespace counter used to name the g_<ns>_Args<n> arrays
    opcode_namespace = db.opcode_namespace
    for i in sorted(db.intrinsics, key=lambda x: x.key):
        if last_ns != i.ns:
            last_ns = i.ns
            id_prefix = (
                "IOP" if last_ns == "Intrinsics" or last_ns == "VkIntrinsics" else "MOP"
            ) # SPIRV Change
            if len(ns_table):
                # Close the previous namespace's table before starting a new one.
                result += ns_table + "};\n"
                # SPIRV Change Starts
                if is_vk_table:
                    result += "\n#endif // ENABLE_SPIRV_CODEGEN\n"
                    is_vk_table = False
                # SPIRV Change Ends
            result += "\n//\n// Start of %s\n//\n\n" % (last_ns)
            # This used to be qualified as __declspec(selectany), but that's no longer necessary.
            ns_table = "static const HLSL_INTRINSIC g_%s[] =\n{\n" % (last_ns)
            # SPIRV Change Starts
            # Only the first intrinsic of a namespace decides whether that
            # namespace's table is guarded.
            if i.vulkanSpecific:
                is_vk_table = True
                result += "#ifdef ENABLE_SPIRV_CODEGEN\n\n"
            # SPIRV Change Ends
            arg_idx = 0
        # Table entry: opcode, readonly/readnone/wave flags, overload param
        # index, parameter count and the name of its argument array.
        ns_table += " {(UINT)%s::%s_%s, %s, %s, %s, %d, %d, g_%s_Args%s},\n" % (
            opcode_namespace,
            id_prefix,
            i.name,
            str(i.readonly).lower(),
            str(i.readnone).lower(),
            str(i.wave).lower(),
            i.overload_param_index,
            len(i.params),
            last_ns,
            arg_idx,
        )
        result += "static const HLSL_INTRINSIC_ARGUMENT g_%s_Args%s[] =\n{\n" % (
            last_ns,
            arg_idx,
        )
        for p in i.params:
            name = p.name
            if name == i.name and i.hidden:
                # First parameter defines intrinsic name for parsing in HLSL.
                # Prepend '$hidden$' for hidden intrinsic so it can't be used in HLSL.
                name = "$hidden$" + name
            result += ' {"%s", %s, %s, %s, %s, %s, %s, %s},\n' % (
                name,
                p.param_qual,
                p.template_id,
                p.template_list,
                p.component_id,
                p.component_list,
                p.rows,
                p.cols,
            )
        result += "};\n\n"
        arg_idx += 1
    # Close the final namespace's table.
    # NOTE(review): with no intrinsics at all this still appends a bare "};";
    # confirm callers never hit that case.
    result += ns_table + "};\n"
    result += (
        "\n#endif // ENABLE_SPIRV_CODEGEN\n" if is_vk_table else ""
    ) # SPIRV Change
    return result
# SPIRV Change Starts
def wrap_with_ifdef_if_vulkan_specific(intrinsic, text):
    """Guard *text* with ENABLE_SPIRV_CODEGEN when *intrinsic* is Vulkan-only.

    *text* is returned unchanged for intrinsics that are not Vulkan-specific.
    """
    if not intrinsic.vulkanSpecific:
        return text
    return "".join(
        ("#ifdef ENABLE_SPIRV_CODEGEN\n", text, "#endif // ENABLE_SPIRV_CODEGEN\n")
    )
# SPIRV Change Ends
def enum_hlsl_intrinsics():
    """Return the enumerator list for the IntrinsicOp enumeration.

    Output, in order:
    - each distinct intrinsic enum name, with Vulkan-specific ones wrapped in
      an ENABLE_SPIRV_CODEGEN guard;
    - each distinct unsigned variant not already listed;
    - the Num_Intrinsics terminator.
    """
    db = get_db_hlsl()
    ordered = sorted(db.intrinsics, key=lambda x: x.key)
    seen = set()
    result = ""
    for intrin in ordered:
        if intrin.enum_name in seen:
            continue
        enumerant = " %s,\n" % (intrin.enum_name)
        result += wrap_with_ifdef_if_vulkan_specific(intrin, enumerant) # SPIRV Change
        seen.add(intrin.enum_name)
    # unsigned
    result += " // unsigned\n"
    for intrin in ordered:
        if intrin.unsigned_op == "" or intrin.unsigned_op in seen:
            continue
        result += " %s,\n" % (intrin.unsigned_op)
        seen.add(intrin.unsigned_op)
    result += " Num_Intrinsics,\n"
    return result
def has_unsigned_hlsl_intrinsics():
    """Return one ``case IntrinsicOp::X:`` label per intrinsic with an unsigned variant."""
    db = get_db_hlsl()
    labels = []
    seen = set()
    # unsigned
    for intrin in sorted(db.intrinsics, key=lambda x: x.key):
        if intrin.unsigned_op == "" or intrin.enum_name in seen:
            continue
        seen.add(intrin.enum_name)
        labels.append(" case IntrinsicOp::%s:\n" % (intrin.enum_name))
    return "".join(labels)
def get_unsigned_hlsl_intrinsics():
    """Return switch cases mapping each intrinsic to its unsigned variant's opcode."""
    db = get_db_hlsl()
    cases = []
    seen = set()
    # unsigned
    for intrin in sorted(db.intrinsics, key=lambda x: x.key):
        if intrin.unsigned_op == "" or intrin.enum_name in seen:
            continue
        seen.add(intrin.enum_name)
        cases.append(" case IntrinsicOp::%s:\n" % (intrin.enum_name))
        cases.append(
            " return static_cast<unsigned>(IntrinsicOp::%s);\n" % (intrin.unsigned_op)
        )
    return "".join(cases)
def get_oloads_props():
    """Return the generated OP::m_OpCodeProps table as a string."""
    generator = db_oload_gen(get_db_dxil())
    return run_with_stdout(generator.print_opfunc_props)
def get_oloads_funcs():
    """Return the generated OP::GetOpFunc case table as a string."""
    generator = db_oload_gen(get_db_dxil())
    return run_with_stdout(generator.print_opfunc_table)
def get_funcs_oload_type():
    """Return the generated OP::GetOverloadType case labels as a string."""
    generator = db_oload_gen(get_db_dxil())
    return run_with_stdout(generator.print_opfunc_oload_type)
def get_enum_decl(name, **kwargs):
    """Return the C++ declaration of DXIL enumeration *name*.

    *kwargs* are forwarded to db_enumhelp_gen.print_enum.
    """
    db = get_db_dxil()
    generator = db_enumhelp_gen(db)

    def emit():
        generator.print_enum(db.enum_idx[name], **kwargs)

    return run_with_stdout(emit)
def get_rdat_enum_decl(name, **kwargs):
    """Return the RDAT_ENUM_VALUE lines for DXIL enumeration *name*.

    *kwargs* are forwarded to db_enumhelp_gen.print_rdat_enum.
    """
    db = get_db_dxil()
    generator = db_enumhelp_gen(db)

    def emit():
        generator.print_rdat_enum(db.enum_idx[name], **kwargs)

    return run_with_stdout(emit)
def get_valrule_enum():
    """Return the ValidationRule enum declaration without explicit values."""
    return get_enum_decl(name="ValidationRule", hide_val=True)
def get_valrule_text():
    """Return a C++ switch mapping each ValidationRule value to its error message."""
    db = get_db_dxil()
    cases = [
        " case hlsl::ValidationRule::" + rule.name + ': return "' + rule.err_msg + '";\n'
        for rule in db.enum_idx["ValidationRule"].values
    ]
    return "switch(value) {\n" + "".join(cases) + "}\n"
def get_instrhelper():
    """Return the generated instruction helper structs (without header/footer)."""
    generator = db_instrhelp_gen(get_db_dxil())
    return run_with_stdout(generator.print_body)
def get_instrs_pred(varname, pred, attr_name="dxil_opid"):
    """Return C++ code testing whether *varname* matches selected instructions.

    *pred* is either the name of an instruction attribute (used as a boolean)
    or a callable taking an instruction. *attr_name* names the numeric
    attribute compared against *varname* (``dxil_opid`` by default).

    The result is a comment listing the matched instructions, followed by a
    ``return <condition>;`` statement. When nothing matches, the statement is
    ``return false;``.
    """
    db = get_db_dxil()
    if isinstance(pred, str):
        pred_fn = lambda i: getattr(i, pred)
    else:
        pred_fn = pred
    matched = [i for i in db.instr if pred_fn(i)]
    result = format_comment(
        "// ",
        "Instructions: %s"
        % ", ".join([i.name + "=" + str(getattr(i, attr_name)) for i in matched]),
    )
    # build_range_code returns "" for no values, which would otherwise
    # generate the invalid statement "return ;".
    condition = build_range_code(varname, [getattr(i, attr_name) for i in matched])
    result += "return %s;" % (condition or "false")
    result += "\n"
    return result
def counter_pred(name, dxil_op=True):
    """Build a predicate that selects instructions listing counter *name*.

    The returned callable is truthy for an instruction whose ``is_dxil_op``
    equals *dxil_op* and whose ``props["counters"]`` contains *name*.
    Instructions with no (or empty) ``props`` are not matched.
    """

    def pred(i):
        # Default of None: instructions without a props attribute simply do
        # not match instead of raising AttributeError.
        props = getattr(i, "props", None)
        return (
            (dxil_op == i.is_dxil_op)
            and props
            and "counters" in props
            and name in props["counters"]
        )

    return pred
def get_counters():
    """Return the ``counters`` collection of the DXIL database."""
    return get_db_dxil().counters
def get_llvm_op_counters():
    """Return, in ``counters`` order, the counters also in ``llvm_op_counters``."""
    db = get_db_dxil()
    llvm_counters = db.llvm_op_counters
    return list(filter(lambda c: c in llvm_counters, db.counters))
def get_dxil_op_counters():
    """Return, in ``counters`` order, the counters also in ``dxil_op_counters``."""
    db = get_db_dxil()
    dxil_counters = db.dxil_op_counters
    return list(filter(lambda c: c in dxil_counters, db.counters))
def get_instrs_rst():
    "Create an rst table of allowed LLVM instructions."
    db = get_db_dxil()
    allowed = sorted(
        [i for i in db.instr if i.is_allowed and not i.is_dxil_op],
        key=lambda v: v.llvm_id,
    )
    rows = [["Instruction", "Action", "Operand overloads"]]
    rows.extend([i.name, i.doc, i.oload_types] for i in allowed)
    parts = ["\n\n", format_rst_table(rows), "\n\n"]
    # Add detailed instruction information where available.
    for i in allowed:
        if i.remarks:
            parts.append(i.name + "\n" + ("~" * len(i.name)) + "\n\n" + i.remarks + "\n\n")
    parts.append("\n")
    return "".join(parts)
def get_init_passes(category_libs):
    "Create a series of statements to initialize passes in a registry."
    db = get_db_dxil()
    # Only passes whose category_lib is among the requested libraries.
    return "".join(
        "initialize%sPass(Registry);\n" % p.type_name
        for p in sorted(db.passes, key=lambda p: p.type_name)
        if p.category_lib in category_libs
    )
def get_pass_arg_names():
    "Return an ArrayRef of argument names based on passName"
    db = get_db_dxil()
    decls = []
    checks = []
    for p in sorted(db.passes, key=lambda p: p.type_name):
        # Passes without arguments contribute nothing.
        if not len(p.args):
            continue
        quoted = ", ".join('"%s"' % a.name for a in p.args)
        decls.append("static const LPCSTR %sArgs[] = { " % p.type_name + quoted + " };\n")
        checks.append(
            'if (strcmp(passName, "%s") == 0) return ArrayRef<LPCSTR>(%sArgs, _countof(%sArgs));\n'
            % (p.name, p.type_name, p.type_name)
        )
    return "".join(decls) + "".join(checks)
def get_pass_arg_descs():
    "Return an ArrayRef of argument descriptions based on passName"
    db = get_db_dxil()
    decls = []
    checks = []
    for p in sorted(db.passes, key=lambda p: p.type_name):
        # Passes without arguments contribute nothing.
        if not len(p.args):
            continue
        quoted = ", ".join('"%s"' % a.doc for a in p.args)
        decls.append("static const LPCSTR %sArgs[] = { " % p.type_name + quoted + " };\n")
        checks.append(
            'if (strcmp(passName, "%s") == 0) return ArrayRef<LPCSTR>(%sArgs, _countof(%sArgs));\n'
            % (p.name, p.type_name, p.type_name)
        )
    return "".join(decls) + "".join(checks)
def get_is_pass_option_name():
    "Create a return expression to check whether a value 'S' is a pass option name."
    db = get_db_dxil()
    checks = ['S.equals("%s")' % k for k in sorted(db.pass_idx_args)]
    return "return " + "\n || ".join(checks) + ";"
def get_opcodes_rst():
    "Create an rst table of opcodes"
    db = get_db_dxil()
    by_id = sorted(
        [i for i in db.instr if i.is_allowed and i.is_dxil_op],
        key=lambda v: v.dxil_opid,
    )
    rows = [["ID", "Name", "Description"]]
    for i in by_id:
        # Append "_" when remarks exist to enable internal hyperlink on rst files.
        label = i.dxil_op + "_" if i.remarks else i.dxil_op
        rows.append([i.dxil_opid, label, i.doc])
    result = "\n\n" + format_rst_table(rows) + "\n\n"
    # Add detailed instruction information where available, ordered by name.
    for i in sorted(by_id, key=lambda v: v.name):
        if i.remarks:
            result += i.name + "\n" + ("~" * len(i.name)) + "\n\n" + i.remarks + "\n\n"
    return result + "\n"
def get_valrules_rst():
    "Create an rst table of validation rules instructions."
    db = get_db_dxil()
    enabled = sorted(
        [r for r in db.val_rules if not r.is_disabled], key=lambda r: r.name
    )
    table = [["Rule Code", "Description"]] + [[r.name, r.doc] for r in enabled]
    return "\n\n" + format_rst_table(table) + "\n\n"
def get_opsigs():
    """Return the C++ ``OpCodeSignatures`` array, one entry per DXIL op, by opcode ID.

    Each entry lists the names of the operands after position 1. Position 0 is
    the return value and position 1 is the opcode itself. db_dxil already
    asserts that opcode numbering is dense, so the array index equals the opcode.
    """
    ops = sorted(
        (i for i in get_db_dxil().instr if i.is_dxil_op), key=lambda v: v.dxil_opid
    )
    final_index = len(ops) - 1
    entries = []
    for index, inst in enumerate(ops):
        operand_total = len(inst.ops)
        params = "".join(
            operand.name + ("," if operand.pos < operand_total - 1 else "")
            for operand in inst.ops
            if operand.pos > 1
        )
        separator = "," if index < final_index else ""
        entries.append(' "(' + params + ')"' + separator + " // " + inst.name + "\n")
    return "static const char *OpCodeSignatures[] = {\n" + "".join(entries) + "};\n"
# Maps a stage name from an instruction's shader_stages to the DXIL::ShaderKind
# enumerator used inside generated SFLAG(...) expressions.
shader_stage_to_ShaderKind = dict(
    vertex="Vertex",
    pixel="Pixel",
    geometry="Geometry",
    compute="Compute",
    hull="Hull",
    domain="Domain",
    library="Library",
    raygeneration="RayGeneration",
    intersection="Intersection",
    anyhit="AnyHit",
    closesthit="ClosestHit",
    miss="Miss",
    callable="Callable",
    mesh="Mesh",
    amplification="Amplification",
    node="Node",
)
def get_min_sm_and_mask_text():
    """Generate C++ setting the minimum shader model and shader-stage mask per DXIL op.

    DXIL operations are sorted by (shader_model, shader_model_translated,
    shader_stages, dxil_opid). They are then grouped into runs of consecutive
    ops sharing the first three keys. Each run becomes a commented
    ``if (<op range>) { ... return; }`` block, which:
      * assigns ``major``/``minor`` when the model is not (6, 0); when
        shader_model_translated is set, it picks that value under
        ``bWithTranslation``;
      * assigns ``mask`` from SFLAG(...) terms when shader_stages is non-empty.
    A run needing neither (model (6, 0) and no stages) emits nothing, so the
    generated code falls through for those ops.
    """
    db = get_db_dxil()
    instrs = [i for i in db.instr if i.is_dxil_op]
    # Sorting makes ops with identical requirements adjacent; dxil_opid keeps
    # each run in opcode order.
    # NOTE(review): the sort compares shader_model_translated across ops, which
    # assumes those values are mutually comparable (e.g. not a mix of None and
    # tuples) -- confirm against db_dxil.
    instrs = sorted(
        instrs,
        key=lambda v: (
            v.shader_model,
            v.shader_model_translated,
            v.shader_stages,
            v.dxil_opid,
        ),
    )
    last_model = None
    last_model_translated = None
    last_stage = None
    grouped_instrs = []
    code = ""

    def flush_instrs(grouped_instrs, last_model, last_model_translated, last_stage):
        # Emit the C++ block for one run of ops sharing the same requirements.
        # Returns "" when the run is empty or needs no special handling.
        if len(grouped_instrs) == 0:
            return ""
        result = format_comment(
            "// ",
            "Instructions: %s"
            % ", ".join([i.name + "=" + str(i.dxil_opid) for i in grouped_instrs]),
        )
        result += (
            "if ("
            + build_range_code("op", [i.dxil_opid for i in grouped_instrs])
            + ") {\n"
        )
        default = True
        # Any model other than 6.0 needs an explicit major/minor assignment.
        if last_model != (6, 0):
            default = False
            if last_model_translated:
                # Use the translated model when bWithTranslation is set.
                result += " if (bWithTranslation) {\n"
                result += (
                    " major = %d; minor = %d;\n } else {\n "
                    % last_model_translated
                )
            result += " major = %d; minor = %d;\n" % last_model
            if last_model_translated:
                result += " }\n"
        # A non-empty stage list restricts the ops to those shader kinds.
        if last_stage:
            default = False
            result += " mask = %s;\n" % " | ".join(
                ["SFLAG(%s)" % shader_stage_to_ShaderKind[c] for c in last_stage]
            )
        if default:
            # don't write these out, instead fall through
            return ""
        return result + " return;\n}\n"

    # Accumulate consecutive ops with identical requirements; flush a run
    # whenever the grouping key changes.
    for i in instrs:
        if (i.shader_model, i.shader_model_translated, i.shader_stages) != (
            last_model,
            last_model_translated,
            last_stage,
        ):
            code += flush_instrs(
                grouped_instrs, last_model, last_model_translated, last_stage
            )
            grouped_instrs = []
            last_model = i.shader_model
            last_model_translated = i.shader_model_translated
            last_stage = i.shader_stages
        grouped_instrs.append(i)
    # Flush the final run.
    code += flush_instrs(grouped_instrs, last_model, last_model_translated, last_stage)
    return code
# Maps a stage name from an instruction's shader_stages to the C++ condition,
# emitted by get_valopcode_sm_text, that tests the shader kind `SK`.
check_pSM_for_shader_stage = dict(
    vertex="SK == DXIL::ShaderKind::Vertex",
    pixel="SK == DXIL::ShaderKind::Pixel",
    geometry="SK == DXIL::ShaderKind::Geometry",
    compute="SK == DXIL::ShaderKind::Compute",
    hull="SK == DXIL::ShaderKind::Hull",
    domain="SK == DXIL::ShaderKind::Domain",
    library="SK == DXIL::ShaderKind::Library",
    raygeneration="SK == DXIL::ShaderKind::RayGeneration",
    intersection="SK == DXIL::ShaderKind::Intersection",
    anyhit="SK == DXIL::ShaderKind::AnyHit",
    closesthit="SK == DXIL::ShaderKind::ClosestHit",
    miss="SK == DXIL::ShaderKind::Miss",
    callable="SK == DXIL::ShaderKind::Callable",
    mesh="SK == DXIL::ShaderKind::Mesh",
    amplification="SK == DXIL::ShaderKind::Amplification",
    node="SK == DXIL::ShaderKind::Node",
)
def get_valopcode_sm_text():
    """Generate C++ checking whether a DXIL op is allowed for a shader model and kind.

    DXIL operations are sorted by (shader_model, shader_stages, dxil_opid) and
    grouped into runs of consecutive ops with equal model and stages. A run gets
    a commented ``if (<op range>) return <cond>;`` only if its model is not
    (6, 0) or it has a stage list. ``<cond>`` ANDs together a
    ``major``/``minor`` at-least check and an OR of ``SK == DXIL::ShaderKind::...``
    tests, keeping whichever of the two apply. The generated code ends with
    ``return true;``, which covers every op not matched above.
    """
    db = get_db_dxil()
    instrs = [i for i in db.instr if i.is_dxil_op]
    # Sorting makes ops with identical requirements adjacent.
    instrs = sorted(
        instrs, key=lambda v: (v.shader_model, v.shader_stages, v.dxil_opid)
    )
    last_model = None
    last_stage = None
    grouped_instrs = []
    code = ""

    def flush_instrs(grouped_instrs, last_model, last_stage):
        # Emit the check for one run of ops sharing model and stages.
        # Returns "" when the run is empty or imposes no restriction.
        if len(grouped_instrs) == 0:
            return ""
        result = format_comment(
            "// ",
            "Instructions: %s"
            % ", ".join([i.name + "=" + str(i.dxil_opid) for i in grouped_instrs]),
        )
        result += (
            "if ("
            + build_range_code("op", [i.dxil_opid for i in grouped_instrs])
            + ")\n"
        )
        result += " return "
        model_cond = stage_cond = None
        # last_model is indexed as (major, minor) below.
        if last_model != (6, 0):
            model_cond = "major > %d || (major == %d && minor >= %d)" % (
                last_model[0],
                last_model[0],
                last_model[1],
            )
        if last_stage:
            stage_cond = " || ".join(
                [check_pSM_for_shader_stage[c] for c in last_stage]
            )
        if model_cond or stage_cond:
            # AND together only the conditions that apply.
            result += "\n && ".join(
                ["(%s)" % expr for expr in (model_cond, stage_cond) if expr]
            )
            return result + ";\n"
        else:
            # don't write these out, instead fall through
            return ""

    # Accumulate consecutive ops with identical requirements; flush a run
    # whenever the grouping key changes.
    for i in instrs:
        if (i.shader_model, i.shader_stages) != (last_model, last_stage):
            code += flush_instrs(grouped_instrs, last_model, last_stage)
            grouped_instrs = []
            last_model = i.shader_model
            last_stage = i.shader_stages
        grouped_instrs.append(i)
    # Flush the final run.
    code += flush_instrs(grouped_instrs, last_model, last_stage)
    # Ops with no recorded restriction are always valid.
    code += "return true;\n"
    return code
def get_sigpoint_table():
    """Run db_sigpoint_gen's sigpoint-table printer through run_with_stdout.

    Presumably returns the text the printer writes (run_with_stdout is defined
    elsewhere -- confirm).
    """
    printer = db_sigpoint_gen(get_db_dxil()).print_sigpoint_table
    return run_with_stdout(printer)
def get_sigpoint_rst():
    """Create an rst table for SigPointKind.

    The rows come from ``db.sigpoint_table`` without its last row, and the first
    row is the header. Every data row is prefixed with the SigPointKind enum
    value and suffixed with that value's doc string. The enum value is looked up
    by the name in the row's first column. ``db.sigpoint_table`` is not mutated.
    """
    db = get_db_dxil()
    kinds_by_name = {v.name: v for v in db.enum_idx["SigPointKind"].values}
    # The last row is excluded (presumably a sentinel such as Invalid -- confirm).
    table = db.sigpoint_table[:-1]
    rows = [["ID"] + table[0] + ["Description"]]
    for row in table[1:]:
        kind = kinds_by_name[row[0]]
        # List concatenation builds a new row, so the source table stays untouched.
        rows.append([kind.value] + row + [kind.doc])
    return "\n\n" + format_rst_table(rows) + "\n\n"
def get_sem_interpretation_enum_rst():
    """Create an rst table of SemanticInterpretationKind values.

    The last enum value is omitted (presumably a sentinel -- confirm).
    """
    kinds = get_db_dxil().enum_idx["SemanticInterpretationKind"].values[:-1]
    table = [["ID", "Name", "Description"]]
    table += [[kind.value, kind.name, kind.doc] for kind in kinds]
    return "\n\n" + format_rst_table(table) + "\n\n"
def get_sem_interpretation_table_rst():
    """Render db_dxil's interpretation_table as an rst simple table."""
    table = get_db_dxil().interpretation_table
    return "".join(("\n\n", format_rst_table(table), "\n\n"))
def get_interpretation_table():
    """Run db_sigpoint_gen's interpretation-table printer through run_with_stdout.

    Presumably returns the text the printer writes (run_with_stdout is defined
    elsewhere -- confirm).
    """
    printer = db_sigpoint_gen(get_db_dxil()).print_interpretation_table
    return run_with_stdout(printer)
# Highest supported shader model is highest_major.highest_minor; highest_minor
# is also used as the DXIL / validator minor version by the generators.
highest_major, highest_minor = 6, 8
# Highest minor version for each supported major shader model version.
highest_shader_models = {
    4: 1,
    5: 1,
    6: highest_minor,
}
def getShaderModels():
    """Return every shader model version as a "<major>_<minor>" string.

    Majors appear in ``highest_shader_models`` order. Each major is listed with
    minors from 0 through its highest supported minor.
    """
    return [
        "%s_%s" % (major, minor)
        for major, top_minor in highest_shader_models.items()
        for minor in range(top_minor + 1)
    ]
def get_highest_shader_model():
    """Return the C++ declarations of kHighestMajor and kHighestMinor, one per line."""
    declarations = (
        "static const unsigned kHighestMajor = %d;" % highest_major,
        "static const unsigned kHighestMinor = %d;" % highest_minor,
    )
    return "\n".join(declarations)
def get_dxil_version_minor():
    """Return the C++ declaration of kDxilMinor, set to highest_minor."""
    template = "const unsigned kDxilMinor = %d;"
    return template % highest_minor
def get_dxil_version_minor_int():
    """Return the DXIL minor version as a plain value (it is highest_minor)."""
    minor = highest_minor
    return minor
def get_is_shader_model_plus():
    """Return one ``IsSM<major><minor>Plus()`` C++ accessor per minor of highest_major.

    Minors run from 0 through highest_minor, and each accessor delegates to
    ``IsSMAtLeast``.
    """
    template = "bool IsSM%d%dPlus() const { return IsSMAtLeast(%d, %d); }\n"
    return "".join(
        template % (highest_major, minor, highest_major, minor)
        for minor in range(highest_minor + 1)
    )
# Maps a shader profile prefix to its DXIL::ShaderKind C++ enumerator spelling.
# The values match the enum_name column of shader_profiles. Every entry must be
# a Kind::* enumerator, not a starting shader-model version string.
profile_to_kind = {
    "ps": "Kind::Pixel",
    "vs": "Kind::Vertex",
    "gs": "Kind::Geometry",
    "hs": "Kind::Hull",
    "ds": "Kind::Domain",
    "cs": "Kind::Compute",
    "lib": "Kind::Library",
    "ms": "Kind::Mesh",
    "as": "Kind::Amplification",
}
class shader_profile(object):
    """Description of one shader profile used to generate the shader model table.

    Attributes:
        kind: numeric DXIL::ShaderKind value; it forms the high bits of the
            shader-model hash (``kind << 16 | major << 8 | minor``).
        kind_name: profile prefix such as "ps" or "lib".
        enum_name: C++ enumerator spelling such as "Kind::Pixel".
        start_sm: lowest shader model for this profile, as "<major>_<minor>".
        input_size: default input size passed to the generated SM(...) row
            (presumably the signature element limit -- confirm).
        output_size: default output size passed to the generated SM(...) row
            (presumably the signature element limit -- confirm).
    """

    def __init__(self, kind, kind_name, enum_name, start_sm, input_size, output_size):
        self.kind = kind  # numeric DXIL::ShaderKind value
        self.kind_name = kind_name
        self.enum_name = enum_name
        self.start_sm = start_sm
        self.input_size = input_size
        self.output_size = output_size

    def __repr__(self):
        return "%s(%r, %r, %r, %r, %r, %r)" % (
            type(self).__name__,
            self.kind,
            self.kind_name,
            self.enum_name,
            self.start_sm,
            self.input_size,
            self.output_size,
        )
# Profiles emitted into the shader model table, in DXIL::ShaderKind order.
# Columns: kind (numeric DXIL::ShaderKind), kind_name, enum_name, start_sm,
# input_size, output_size.
shader_profiles = [
    shader_profile(*columns)
    for columns in (
        (0, "ps", "Kind::Pixel", "4_0", 32, 8),
        (1, "vs", "Kind::Vertex", "4_0", 32, 32),
        (2, "gs", "Kind::Geometry", "4_0", 32, 32),
        (3, "hs", "Kind::Hull", "5_0", 32, 32),
        (4, "ds", "Kind::Domain", "5_0", 32, 32),
        (5, "cs", "Kind::Compute", "4_0", 0, 0),
        (6, "lib", "Kind::Library", "6_1", 32, 32),
        (13, "ms", "Kind::Mesh", "6_5", 0, 0),
        (14, "as", "Kind::Amplification", "6_5", 0, 0),
    )
]
def getShaderProfiles():
    """Return (profile prefix, minimum shader model) pairs in DXIL::ShaderKind order.

    The pairs are derived from ``shader_profiles`` rather than kept in a second
    hand-maintained table, so the two cannot drift apart.

    Returns:
        tuple of (kind_name, start_sm) string pairs, e.g. ("ps", "4_0").
    """
    return tuple((p.kind_name, p.start_sm) for p in shader_profiles)
def get_shader_models():
    """Generate the ``SM(...)`` initializer rows of the C++ shader model table.

    Rows are grouped by profile (``shader_profiles`` order), then by major and
    minor version from ``highest_shader_models``. Versions below a profile's
    ``start_sm`` are skipped. A ``lib_6_x`` row follows the library profile, and
    an ``Invalid`` sentinel row ends the table. The row order must stay in sync
    with ``get_num_shader_models`` and ``build_shader_model_hash_idx_map``.
    """
    result = ""
    for profile in shader_profiles:
        kind_name = profile.kind_name
        # Compare versions numerically; as strings, "6_10" would sort before "6_5".
        min_version = tuple(int(part) for part in profile.start_sm.split("_"))
        for major, minor in highest_shader_models.items():
            # Trailing three SM() arguments describe UAV support for this major.
            if major == 4:
                # Test the profile *name*: profile.kind is the numeric ShaderKind
                # and would never equal "cs".
                UAV_info = "true, false, 1" if kind_name == "cs" else "false, false, 0"
            elif major == 5:
                UAV_info = "true, true, 64"
            else:
                UAV_info = "true, true, UINT_MAX"
            for i in range(0, minor + 1):
                if (major, i) < min_version:
                    continue
                input_size = profile.input_size
                output_size = profile.output_size
                # 4_0 geometry and vertex shaders use smaller input/output sizes.
                if major == 4 and i == 0:
                    if kind_name == "gs":
                        input_size = 16
                    elif kind_name == "vs":
                        input_size = 16
                        output_size = 16
                sm_name = "%s_%d_%d" % (kind_name, major, i)
                result += 'SM(%s, %d, %d, "%s", %d, %d, %s),\n' % (
                    profile.enum_name,
                    major,
                    i,
                    sm_name,
                    input_size,
                    output_size,
                    UAV_info,
                )
        if kind_name == "lib":
            result += (
                "// lib_6_x is for offline linking only, and relaxes restrictions\n"
            )
            result += 'SM(Kind::Library, 6, kOfflineMinor, "lib_6_x", 32, 32, true, true, UINT_MAX),\n'
    result += (
        "// Values before Invalid must remain sorted by Kind, then Major, then Minor.\n"
    )
    result += 'SM(Kind::Invalid, 0, 0, "invalid", 0, 0, false, false, 0),\n'
    return result
def get_num_shader_models():
    """Return the C++ declaration of kNumShaderModels.

    The count must equal the number of rows emitted by ``get_shader_models``.
    That is one row per (profile, version) at or above the profile's
    ``start_sm``, plus one for ``lib_6_x``, plus one for the ``Invalid`` sentinel.
    """
    count = 0
    for profile in shader_profiles:
        # Compare versions numerically; string comparison misorders two-digit minors.
        min_version = tuple(int(part) for part in profile.start_sm.split("_"))
        for major, minor in highest_shader_models.items():
            count += sum(1 for i in range(minor + 1) if (major, i) >= min_version)
        if profile.kind_name == "lib":
            count += 1  # for lib_6_x
    count += 1  # for the invalid shader model
    return "static const unsigned kNumShaderModels = %d;" % count
def build_shader_model_hash_idx_map():
    """Return the C++ ``hashToIdxMap`` pairing shader-model hashes with table indices.

    Entries must follow the row order produced by ``get_shader_models``; the
    second value of each pair is the row index. Each hash is
    ``kind << 16 | major << 8 | minor``, and ``lib_6_x`` uses minor 15
    (``kOfflineMinor``). The generated lookup searches this array with
    ``std::lower_bound``, so hashes must be strictly increasing. This is
    asserted here so that a collision, such as a real minor reaching 15, fails
    at generation time.
    """
    result = "const static std::pair<unsigned, unsigned> hashToIdxMap[] = {\n"
    hashes = []
    for profile in shader_profiles:
        kind = profile.kind
        kind_name = profile.kind_name
        # Compare versions numerically; string comparison misorders two-digit minors.
        min_version = tuple(int(part) for part in profile.start_sm.split("_"))
        for major, minor in highest_shader_models.items():
            for i in range(0, minor + 1):
                if (major, i) < min_version:
                    continue
                sm_name = "%s_%d_%d" % (kind_name, major, i)
                hash_v = kind << 16 | major << 8 | i
                result += "{%d,%d}, //%s\n" % (hash_v, len(hashes), sm_name)
                hashes.append(hash_v)
        if kind_name == "lib":
            result += (
                "// lib_6_x is for offline linking only, and relaxes restrictions\n"
            )
            # static const unsigned kOfflineMinor = 0xF;
            hash_v = kind << 16 | 6 << 8 | 15
            result += "{%d,%d},//%s\n" % (hash_v, len(hashes), "lib_6_x")
            hashes.append(hash_v)
    assert all(
        a < b for a, b in zip(hashes, hashes[1:])
    ), "hashToIdxMap must be strictly increasing for std::lower_bound"
    result += "};\n"
    return result
def get_validation_version():
    """Return the C++ body that reports validator version 1.<highest_minor>.

    A changelog comment listing what each validator minor version added precedes
    the assignments to ``*pMajor`` and ``*pMinor``.
    NOTE(review): the changelog text may lag behind highest_minor -- confirm it
    is up to date.
    """
    template = """// 1.0 is the first validator.
// 1.1 adds:
// - ILDN container part support
// 1.2 adds:
// - Metadata for floating point denorm mode
// 1.3 adds:
// - Library support
// - Raytracing support
// - i64/f64 overloads for rawBufferLoad/Store
// 1.4 adds:
// - packed u8x4/i8x4 dot with accumulate to i32
// - half dot2 with accumulate to float
// 1.5 adds:
// - WaveMatch, WaveMultiPrefixOp, WaveMultiPrefixBitCount
// - HASH container part support
// - Mesh and Amplification shaders
// - DXR 1.1 & RayQuery support
*pMajor = 1;
*pMinor = %d;
"""
    return template % highest_minor
def get_target_profiles():
    """Return the TableGen ``HelpText`` listing every valid target profile.

    Only shader models of ``highest_major`` are listed, and for each profile
    only those at or above its minimum shader model. Each profile's entries are
    followed by an escaped line break.
    """

    def version(sm):
        # "6_5" -> (6, 5): numeric ordering, so "6_10" sorts after "6_9".
        return tuple(int(part) for part in sm.split("_"))

    base_version = (highest_major, 0)
    shader_models = getShaderModels()
    result = 'HelpText<"Set target profile. \\n'
    result += "\\t<profile>: "
    for profile, min_sm in getShaderProfiles():
        floor = max(base_version, version(min_sm))
        for shader_model in shader_models:
            if version(shader_model) >= floor:
                result += "%s_%s, " % (profile, shader_model)
        result += "\\n\\t\\t "
    result += '">;'
    return result
def get_min_validator_version():
    """Return C++ ``case`` arms mapping each minor 0..highest_minor to the same ValMinor."""
    arms = []
    for minor in range(highest_minor + 1):
        arms.append("case %d:\n ValMinor = %d;\n break;\n" % (minor, minor))
    return "".join(arms)
def get_dxil_version():
    """Return C++ ``case`` arms that set DxilMinor.

    Each minor 0..highest_minor maps to itself. A final ``kOfflineMinor`` arm
    maps to highest_minor.
    """
    arms = [
        "case %d:\n DxilMinor = %d;\n break;\n" % (minor, minor)
        for minor in range(highest_minor + 1)
    ]
    arms.append(
        "case kOfflineMinor: // Always update this to highest dxil version\n"
        + " DxilMinor = %d;\n" % highest_minor
        + " break;\n"
    )
    return "".join(arms)
def get_shader_model_get():
    """Return C++ that looks up a shader model by (Kind, Major, Minor).

    The generated code first defines ``hashToIdxMap`` and then hashes the triple
    as ``Kind << 16 | Major << 8 | Minor``. It binary-searches the map with
    ``std::lower_bound`` and returns ``GetInvalid()`` when no exact match exists.
    """
    pieces = [
        build_shader_model_hash_idx_map(),
        "unsigned hash = (unsigned)Kind << 16 | Major << 8 | Minor;\n",
        "auto pred = [](const std::pair<unsigned, unsigned>& elem, unsigned val){ return elem.first < val;};\n",
        "auto it = std::lower_bound(std::begin(hashToIdxMap), std::end(hashToIdxMap), hash, pred);\n",
        "if (it == std::end(hashToIdxMap) || it->first != hash)\n",
        " return GetInvalid();\n",
        "return &ms_ShaderModels[it->second];",
    ]
    return "".join(pieces)
def get_shader_model_by_name():
    """Return C++ ``case`` arms for minor-version characters '2'..highest_minor.

    Each arm accepts the minor only when ``Major`` equals ``highest_major`` and
    otherwise returns ``GetInvalid()``.
    NOTE(review): the labels are character literals, so this only works while
    highest_minor is a single digit.
    """
    arm = (
        "case '%d':\n"
        " if (Major == %d) {\n"
        " Minor = %d;\n"
        " break;\n"
        " }\n"
        "else return GetInvalid();\n"
    )
    return "".join(arm % (i, highest_major, i) for i in range(2, highest_minor + 1))
def get_is_valid_for_dxil():
    """Return a C++ ``case N:`` label for every minor version 0..highest_minor."""
    return "".join("case %d:\n" % minor for minor in range(highest_minor + 1))
def RunCodeTagUpdate(file_path):
    """Run CodeTags over ``file_path`` and rewrite the file if its output differs.

    ``CodeTags.main`` is asked to write to ``<file_path>.tmp``. If it returns
    non-zero, the error code is printed and nothing else happens. Otherwise the
    original file is replaced only when the temporary output differs, and the
    temporary file is then removed.
    """
    import os
    import CodeTags

    print(" ... updating " + file_path)
    tmp_path = file_path + ".tmp"
    status = CodeTags.main([file_path, tmp_path])
    if status != 0:
        print(" ... error: %d" % status)
        return
    with open(file_path, "rt") as src:
        before = src.read()
    with open(tmp_path, "rt") as src:
        after = src.read()
    if before != after:
        print(" +++ changes found, updating file")
        with open(file_path, "wt") as dst:
            dst.write(after)
    else:
        print(" --- no changes found")
    os.remove(tmp_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate code to handle instructions."
)
parser.add_argument(
"-gen",
choices=["docs-ref", "docs-spec", "inst-header", "enums", "oloads", "valfns"],
help="Output type to generate.",
)
parser.add_argument("-update-files", action="store_const", const=True)
args = parser.parse_args()
db = get_db_dxil() # used by all generators, also handy to have it run validation
if args.gen == "docs-ref":
gen = db_docsref_gen(db)
gen.print_content()
if args.gen == "docs-spec":
import os, docutils.core
assert (
"HLSL_SRC_DIR" in os.environ
), "Environment variable HLSL_SRC_DIR is not defined"
hlsl_src_dir = os.environ["HLSL_SRC_DIR"]
spec_file = os.path.abspath(os.path.join(hlsl_src_dir, "docs/DXIL.rst"))
with open(spec_file) as f:
s = docutils.core.publish_file(f, writer_name="html")
if args.gen == "inst-header":
gen = db_instrhelp_gen(db)
gen.print_content()
if args.gen == "enums":
gen = db_enumhelp_gen(db)
gen.print_content()
if args.gen == "oloads":
gen = db_oload_gen(db)
gen.print_content()
if args.gen == "valfns":
gen = db_valfns_gen(db)
gen.print_content()
if args.update_files:
print("Updating files ...")
import CodeTags
import os
assert (
"HLSL_SRC_DIR" in os.environ
), "Environment variable HLSL_SRC_DIR is not defined"
hlsl_src_dir = os.environ["HLSL_SRC_DIR"]
pj = lambda *parts: os.path.abspath(os.path.join(*parts))
files = [
"docs/DXIL.rst",
"lib/DXIL/DxilOperations.cpp",
"lib/DXIL/DxilShaderModel.cpp",
"include/dxc/DXIL/DxilConstants.h",
"include/dxc/DXIL/DxilShaderModel.h",
"include/dxc/HLSL/DxilValidation.h",
"include/dxc/Support/HLSLOptions.td",
"include/dxc/DXIL/DxilInstructions.h",
"lib/HLSL/DxcOptimizer.cpp",
"lib/DxilPIXPasses/DxilPIXPasses.cpp",
"lib/HLSL/DxilValidation.cpp",
"tools/clang/lib/Sema/gen_intrin_main_tables_15.h",
"include/dxc/HlslIntrinsicOp.h",
"tools/clang/tools/dxcompiler/dxcdisassembler.cpp",
"include/dxc/DXIL/DxilSigPoint.inl",
"include/dxc/DXIL/DxilCounters.h",
"lib/DXIL/DxilCounters.cpp",
"lib/DXIL/DxilMetadataHelper.cpp",
"include/dxc/DxilContainer/RDAT_LibraryTypes.inl",
]
for relative_file_path in files:
RunCodeTagUpdate(pj(hlsl_src_dir, relative_file_path))