slangpy
Advanced tools
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| import pytest | ||
| import numpy as np | ||
| import slangpy as spy | ||
| from slangpy import DeviceType, Module | ||
| from slangpy.types.tensor import Tensor | ||
| from slangpy.testing import helpers | ||
| def build_blas( | ||
| device: spy.Device, vertices: np.ndarray, indices: np.ndarray | ||
| ) -> spy.AccelerationStructure: | ||
| vertex_buffer = device.create_buffer( | ||
| usage=spy.BufferUsage.shader_resource, | ||
| label="vertex_buffer", | ||
| data=vertices, | ||
| ) | ||
| index_buffer = device.create_buffer( | ||
| usage=spy.BufferUsage.shader_resource, | ||
| label="index_buffer", | ||
| data=indices, | ||
| ) | ||
| blas_input_triangles = spy.AccelerationStructureBuildInputTriangles( | ||
| { | ||
| "vertex_buffers": [vertex_buffer], | ||
| "vertex_format": spy.Format.rgb32_float, | ||
| "vertex_count": vertices.size // 3, | ||
| "vertex_stride": vertices.itemsize * 3, | ||
| "index_buffer": index_buffer, | ||
| "index_format": spy.IndexFormat.uint32, | ||
| "index_count": indices.size, | ||
| "flags": spy.AccelerationStructureGeometryFlags.opaque, | ||
| } | ||
| ) | ||
| blas_build_desc = spy.AccelerationStructureBuildDesc( | ||
| { | ||
| "inputs": [blas_input_triangles], | ||
| } | ||
| ) | ||
| blas_sizes = device.get_acceleration_structure_sizes(blas_build_desc) | ||
| blas_scratch_buffer = device.create_buffer( | ||
| size=blas_sizes.scratch_size, | ||
| usage=spy.BufferUsage.unordered_access, | ||
| label="blas_scratch_buffer", | ||
| ) | ||
| blas = device.create_acceleration_structure( | ||
| size=blas_sizes.acceleration_structure_size, | ||
| label="blas", | ||
| ) | ||
| command_encoder = device.create_command_encoder() | ||
| command_encoder.build_acceleration_structure( | ||
| desc=blas_build_desc, dst=blas, src=None, scratch_buffer=blas_scratch_buffer | ||
| ) | ||
| device.submit_command_buffer(command_encoder.finish()) | ||
| return blas | ||
| def build_tlas( | ||
| device: spy.Device, instance_list: spy.AccelerationStructureInstanceList | ||
| ) -> spy.AccelerationStructure: | ||
| tlas_build_desc = spy.AccelerationStructureBuildDesc( | ||
| { | ||
| "inputs": [instance_list.build_input_instances()], | ||
| } | ||
| ) | ||
| tlas_sizes = device.get_acceleration_structure_sizes(tlas_build_desc) | ||
| tlas_scratch_buffer = device.create_buffer( | ||
| size=tlas_sizes.scratch_size, | ||
| usage=spy.BufferUsage.unordered_access, | ||
| label="tlas_scratch_buffer", | ||
| ) | ||
| tlas = device.create_acceleration_structure( | ||
| size=tlas_sizes.acceleration_structure_size, | ||
| label="tlas", | ||
| ) | ||
| command_encoder = device.create_command_encoder() | ||
| command_encoder.build_acceleration_structure( | ||
| desc=tlas_build_desc, dst=tlas, src=None, scratch_buffer=tlas_scratch_buffer | ||
| ) | ||
| device.submit_command_buffer(command_encoder.finish()) | ||
| return tlas | ||
| @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) | ||
| def test_raytracing(device_type: DeviceType): | ||
| device = helpers.get_device(device_type) | ||
| if not device.has_feature(spy.Feature.acceleration_structure): | ||
| pytest.skip("Acceleration structures not supported on this device") | ||
| if not device.has_feature(spy.Feature.ray_tracing): | ||
| pytest.skip("Ray tracing not supported on this device") | ||
| vertices = np.array([-1, -1, 0, 1, -1, 0, -1, 1, 0], dtype=np.float32) | ||
| indices = np.array([0, 1, 2], dtype=np.uint32) | ||
| blas = build_blas(device, vertices, indices) | ||
| instance_list = device.create_acceleration_structure_instance_list(1) | ||
| instance_list.write( | ||
| 0, | ||
| { | ||
| "transform": spy.float3x4.identity(), | ||
| "instance_id": 0, | ||
| "instance_mask": 0xFF, | ||
| "instance_contribution_to_hit_group_index": 0, | ||
| "flags": spy.AccelerationStructureInstanceFlags.none, | ||
| "acceleration_structure": blas.handle, | ||
| }, | ||
| ) | ||
| tlas = build_tlas(device, instance_list) | ||
| tensor = Tensor.zeros(device, (64, 64, 3), dtype=float) | ||
| module = Module(device.load_module("test_raytracing.slang")) | ||
| module.trace.ray_tracing( | ||
| hit_groups=[{"hit_group_name": "hit_group", "closest_hit_entry_point": "closest_hit"}], | ||
| miss_entry_points=["miss"], | ||
| max_recursion=1, | ||
| max_ray_payload_size=12, | ||
| )(tid=spy.call_id(), tlas=tlas, _result=tensor) | ||
| data = tensor.to_numpy() | ||
| # spy.tev.show(spy.Bitmap(data)) | ||
| assert np.allclose(data[0, 0, :], [0, 0, 0], atol=0.01) | ||
| assert np.allclose(data[0, 63, :], [1, 0, 0], atol=0.01) | ||
| assert np.allclose(data[63, 0, :], [0, 1, 0], atol=0.01) | ||
| assert np.allclose(data[63, 63, :], [1, 0, 1], atol=0.01) | ||
| if __name__ == "__main__": | ||
| pytest.main([__file__, "-v", "-s"]) |
Sorry, the diff of this file is not supported yet
+2
-2
| Metadata-Version: 2.4 | ||
| Name: slangpy | ||
| Version: 0.34.0 | ||
| Version: 0.35.0 | ||
| Summary: Easily call Slang functions and integrate with PyTorch auto diff directly from Python. | ||
@@ -103,5 +103,5 @@ Author-email: Simon Kallweit <skallweit@nvidia.com>, Chris Cummings <chriscummings@nvidia.com>, Benedikt Bitterli <bbitterli@nvidia.com>, Sai Bangaru <sbangaru@nvidia.com>, Yong He <yhe@nvidia.com> | ||
| note = {https://github.com/shader-slang/slangpy}, | ||
| version = {0.34.0}, | ||
| version = {0.35.0}, | ||
| year = 2025 | ||
| } | ||
| ``` |
+30
-28
| slangpy/__init__.py,sha256=PsbtOD5NRRgDcN2C7SbPnePNpAzce326xQM9tZhBT9M,1948 | ||
| slangpy/__init__.pyi,sha256=L-SxD_v3BYdqsi4I3Uq6qtDXIThLa1pe68fVQfnuOBE,192810 | ||
| slangpy/libsgl.dylib,sha256=Bi8bqjreB8AgnP9PHk62KxeyydQiLf76e8Wq4jgt6EY,8336864 | ||
| slangpy/libslang-glsl-module.dylib,sha256=Hqm_WyMZpzM8m3-ARtacg5QFtECzV_ysBHhrx0ADOTE,1292880 | ||
| slangpy/libslang-glslang.dylib,sha256=9mH1jRPBHOU-RevJHtFJUyXCIYJkKZCFfSwYb1St_9o,7876544 | ||
| slangpy/libslang-rhi.a,sha256=bturAvxqevJ7j3wpuQ7mlXwAhGp1WS0mSUzkoHH1RXE,74103624 | ||
| slangpy/libslang-rt.dylib,sha256=9BKh0LwHqoOOPSYKittTf1PbtlObgoiIKTDjP5pmnaw,249056 | ||
| slangpy/libslang.dylib,sha256=xnn0H5N-CT0C-M3-hQS7Sgr5jX9TdfQm4SNVQSlc3yY,19208720 | ||
| slangpy/slangpy_ext.cpython-310-darwin.so,sha256=HIo0pZthekBN5_vCo83eDQYsIV_v9hhFw09mqT7uZQg,6608624 | ||
| slangpy/__init__.pyi,sha256=uitnhvr1sP-WmPkWxGCbQfZZhXPFNVqP2S1sHqIhJbE,193026 | ||
| slangpy/libsgl.dylib,sha256=fe5P5-act_QWXSUr83y1b4TnSlIFkltDGx-j1r__L2g,8336896 | ||
| slangpy/libslang-glsl-module.dylib,sha256=M-suEgdNUFMlBuSlG5-wUm1hsUcu0r85QcYAhYD2Vf8,1292880 | ||
| slangpy/libslang-glslang.dylib,sha256=0941UZI3GrFNd-JilmnaEsMnglDmndBgvOZFkAnZ948,8033520 | ||
| slangpy/libslang-rhi.a,sha256=FG1-ksx-jWgFvpR3ScrEU4wBU-d_el91yfocrLG8PT0,74108576 | ||
| slangpy/libslang-rt.dylib,sha256=C7DXKhnoycvr4OiHKSSfWXgqtML4mnyg50Lm_gsZQaQ,249968 | ||
| slangpy/libslang.dylib,sha256=3wuOwjodITu-7CPh7XtRX5S7q7PIcm1hGgz8HBO3d8Q,19373712 | ||
| slangpy/slangpy_ext.cpython-310-darwin.so,sha256=ghED3AR43maqxD68k0YbXI5-OwCGUL1CAvx0emmB6ZI,6630720 | ||
| slangpy/benchmarks/conftest.py,sha256=rAYyblUkfrklcEe_TbJ0UuSAzQ8b0pmAbrWG2_s1nrg,438 | ||
@@ -34,11 +34,11 @@ slangpy/benchmarks/test_benchmark_interop.py,sha256=7xkivr7F1sD7_VRUVivVy125whirpZkbuttauURqgDw,15467 | ||
| slangpy/core/__init__.py,sha256=TJYDzVSv8gf6xRO-8P6Fg8BpYxp3rjdxiUgt18yd_R4,94 | ||
| slangpy/core/calldata.py,sha256=z_suSAbkfptrWhIB0tBQgoN_AOfXBjSPg_HQG5GOfd8,18003 | ||
| slangpy/core/callsignature.py,sha256=UqZgXTkXT7ObkB7SOi2iDlDd3FQUkSu_dEtjO79TOY8,25905 | ||
| slangpy/core/dispatchdata.py,sha256=J0yHRgwYJno6n9WStq3mbdJLJuSQiZmcrdsO52cBMJ0,10739 | ||
| slangpy/core/calldata.py,sha256=oVK6IexBHGWsp7Qffx1wS75QymE8V57HCRCKC-bmxUM,21458 | ||
| slangpy/core/callsignature.py,sha256=_ljcrJe1Z8E6ckaWaUaDkWuauP3ZGzBl3EEMlGPT6XU,26970 | ||
| slangpy/core/dispatchdata.py,sha256=S176OEjLyzVovpkomRllpxVproTy4KttSRnU38ie5Z0,10955 | ||
| slangpy/core/enums.py,sha256=lLLVsfRqLWNMDOE7eaxOyy6t6KNzbXjv2JsvmepltU0,211 | ||
| slangpy/core/function.py,sha256=dnor5p_ANTHnpGS7tUYRT8n6OzVi6abEx5V__ODW0iw,20499 | ||
| slangpy/core/function.py,sha256=25UwWJ2-Hjesiyk0GextPawvmnHLV5vgsMEr6QN6pDA,23663 | ||
| slangpy/core/instance.py,sha256=Q-Jjt1rU9p1-j_aFQA01DdWCqKaPGfn3C3xgZ_QF8eU,4535 | ||
| slangpy/core/jupyter.py,sha256=iRB4SgmlI3qCk0itlUUGYjB--BwvloKpvB8ZT7UOzVY,9676 | ||
| slangpy/core/logging.py,sha256=LW5fu_5neXYMwnsF8dYXalmQGZJl6j3xWTfKgJ_xZ0k,8688 | ||
| slangpy/core/module.py,sha256=wt2uQYR0Y-Yl1tyfDzsd6sMEAOeNJMWwXis3tftjjRI,8134 | ||
| slangpy/core/module.py,sha256=4q9evVKFmNpB81mEUiiyOoV8rGLrJkId1tGQd6K4LXc,8215 | ||
| slangpy/core/native.py,sha256=SpEiYl6WrtB8c1n7CBUVcdsAmnmqpN37DS3JbtPMEqU,151 | ||
@@ -103,3 +103,3 @@ slangpy/core/packedarg.py,sha256=8yD7ZHZ3I7J2LHGVKdMDpBL7UK-11XrJ5bU6UCsScYY,2060 | ||
| slangpy/experimental/gridarg.py,sha256=ZRTZCmN3dQLS0oTZ7JPf9CK8GpEpC489XkaQG4wUyj4,4207 | ||
| slangpy/include/slang-deprecated.h,sha256=Y3lFxhYrnatDzVFybqurn_r2k3hB_NjkM3zU4gQczDU,69513 | ||
| slangpy/include/slang-deprecated.h,sha256=4HgvR46cTImMBvQI3u2Weukaz9hykr0P_xLVG1ZZgSA,69655 | ||
| slangpy/include/slang-image-format-defs.h,sha256=Na2baM-oiRV7K_Sk7zkmO0rV8Pdg_-5mrKA0OOg2O80,2421 | ||
@@ -109,4 +109,4 @@ slangpy/include/slang-rhi-config.h,sha256=Em9yoqMRi3H21r1rWhJgTwSbWUk88IEQzrQ-ot2Ppf0,403 | ||
| slangpy/include/slang-user-config.h,sha256=-DTfKSDXKnCWR_NU6n96RYbQyYmM5ms2IFQH8UWNn7Y,151 | ||
| slangpy/include/slang.h,sha256=x0nFtGeAM_HmiReKpSjzPIPdyOTnWRAXkNt-KCR6paE,178544 | ||
| slangpy/include/sgl/sgl.h,sha256=R3uIVJZJM5Q-CUtCv9KDCvvpEjyLerKONwt5jpT7aTU,610 | ||
| slangpy/include/slang.h,sha256=sQbZxKt6WhP8F8wvU8EnyMzltucfmIWgQ_7tdZnhRvM,179036 | ||
| slangpy/include/sgl/sgl.h,sha256=XEOQ9UxGba9xzNvoyrUDI7QNRAmxa5SOJUF8sWUWtKM,610 | ||
| slangpy/include/sgl/sgl_pch.h,sha256=SD1ALBpwwkcMlCnU_V8nHdSvS6-NZIXZOv5Rd6_r5z8,131 | ||
@@ -121,3 +121,3 @@ slangpy/include/sgl/app/app.h,sha256=tyvkNXNrsvsYWL98uSxVKcYesdiJ2ZRqdwJpfw9dbwA,2809 | ||
| slangpy/include/sgl/core/error.h,sha256=o-Mi9Fo-j3sDX1BZnIljNdLQtRylKiJuPYr3AQvCCS0,6727 | ||
| slangpy/include/sgl/core/file_stream.h,sha256=KvaJHbPP-WSReT9k7DeMHtiVbRNkyzuvevkqAcdVlmI,1787 | ||
| slangpy/include/sgl/core/file_stream.h,sha256=7JXL14o9sdLC4DJmKoyJ5jWyN6wIBdK1ALudklAXYQs,1509 | ||
| slangpy/include/sgl/core/file_system_watcher.h,sha256=1D_gaJh-GtPUNEiBySx46VwPZw4rvjeh2e_qE3tUlyo,3968 | ||
@@ -140,3 +140,3 @@ slangpy/include/sgl/core/format.h,sha256=WP8Qy7Y8sfnJFiHqodsmzPsi8W5VRlQM5-8gpi0WNqU,903 | ||
| slangpy/include/sgl/core/static_vector.h,sha256=Giyz8df4Cz0apEfaAOyPHx51pd8X-TO_AiDeq4fdP14,3070 | ||
| slangpy/include/sgl/core/stream.h,sha256=RNwFNcYmq6lPC9-zZprqSySyV9UabevNuuDQgWHZ5Fs,1507 | ||
| slangpy/include/sgl/core/stream.h,sha256=mdrcEPnkxZSdrzI9pjoqCvEsGFkOX0KxJsRV5UlvIJs,1786 | ||
| slangpy/include/sgl/core/string.h,sha256=azYJO32kfk9ThHoLsEjQMiTiYUPt-2Uadve1mEvA0v8,8223 | ||
@@ -188,3 +188,3 @@ slangpy/include/sgl/core/thread.h,sha256=A9W4s8uQ2_BAxL-yUZlGZaIrWSngO9bMUZiGzyK7qpE,670 | ||
| slangpy/include/sgl/math/matrix.h,sha256=ipSvS3MPZqHHKbahaG4OSowi4WxUeU2WqV45sDCQQQo,143 | ||
| slangpy/include/sgl/math/matrix_math.h,sha256=uYo82bL_D2IKntAQ-c5TipUQ7YFOkR9NO7t3LR-07zo,24114 | ||
| slangpy/include/sgl/math/matrix_math.h,sha256=jazXRhe9NoPrwrXQnYG_t8yFmGvYlMrRsNt28CUE7EQ,24382 | ||
| slangpy/include/sgl/math/matrix_types.h,sha256=n-Yg6I3EmZIsWoKxM6xLFW9i0Kgv4axTDdy7tO5NlSI,5645 | ||
@@ -216,3 +216,3 @@ slangpy/include/sgl/math/quaternion.h,sha256=SjZfbLOa0GI7eGIyyi0exmcyNIG6L9T1J3A92nYLCek,151 | ||
| slangpy/math/__init__.py,sha256=iasSej-enqU0bZTocJyJYiXTfADJ9sd89NwHtF1PDI0,58 | ||
| slangpy/math/__init__.pyi,sha256=x1kH1ceRfEIZJ5fNipmAI-UuilqP3wRA133pAg9Og-4,113876 | ||
| slangpy/math/__init__.pyi,sha256=Qp57jIhdbzO3YpMu6wyo1CfVDi8HJRdu8xJL1Ysx7Xs,113931 | ||
| slangpy/platform/__init__.py,sha256=iasSej-enqU0bZTocJyJYiXTfADJ9sd89NwHtF1PDI0,58 | ||
@@ -243,3 +243,3 @@ slangpy/platform/__init__.pyi,sha256=Xew0oHwYHVdPnScU_wCGmf1dVnfjPhXDfdlzycqjL9U,2599 | ||
| slangpy/slangpy/__init__.py,sha256=iasSej-enqU0bZTocJyJYiXTfADJ9sd89NwHtF1PDI0,58 | ||
| slangpy/slangpy/__init__.pyi,sha256=8bJzuPxdgETTHbvRSpvVldDk37T88E7LDjK43LLJILU,22329 | ||
| slangpy/slangpy/__init__.pyi,sha256=ETgClZZDNoMZqtHrORVjG7_J34Vp8EzyUpjBwOdpIEM,22491 | ||
| slangpy/testing/__init__.py,sha256=iasSej-enqU0bZTocJyJYiXTfADJ9sd89NwHtF1PDI0,58 | ||
@@ -323,3 +323,3 @@ slangpy/testing/helpers.py,sha256=aQBRM5wYX_MW11cqWzd_HBnjekS28KsHTEaOVOoaXQM,15087 | ||
| slangpy/tests/slangpy_tests/test_buffers.py,sha256=f0bNcEB_E20QS4ag505SzbKSk_HtXVR6oDAlcUV6d_M,1265 | ||
| slangpy/tests/slangpy_tests/test_caching.py,sha256=lw5uauQyGuc7Mk1ggy_Ps96AJQ3ttrCTeMbK-F8DfpM,2045 | ||
| slangpy/tests/slangpy_tests/test_caching.py,sha256=O3jBBSeuMBkYNkRZdpFlISnLA8C-nxo1t0TOIpmRkDM,2029 | ||
| slangpy/tests/slangpy_tests/test_call_group_integrations.py,sha256=Zid7TVEjFK7XkbFsLSNDLFPp1tt5pQ3FAxX29XudA5E,32480 | ||
@@ -349,2 +349,4 @@ slangpy/tests/slangpy_tests/test_call_groups.py,sha256=_1BD8Upp-qeCbUsI--NaZ2zLKZJvjxZAjH3kBEd25RY,26207 | ||
| slangpy/tests/slangpy_tests/test_raw_dispatch.py,sha256=pj6ut45b9I1mRrBTntlyp8eCbaruY3JWVGurFoJ1Q_g,6140 | ||
| slangpy/tests/slangpy_tests/test_raytracing.py,sha256=Dww27b51bZIBxD6hvzoi5Sxrzlqcxk8E0dW_isrVOTs,4720 | ||
| slangpy/tests/slangpy_tests/test_raytracing.slang,sha256=-q5YFKPv144lQxMc2j1DCTJFq1rk0c1QJxtjCF61Cxw,903 | ||
| slangpy/tests/slangpy_tests/test_reflection2.py,sha256=qO6EIWdnCYyHa-TQIU5m31WiU15Ga2oLp-41qjBTEUc,12989 | ||
@@ -388,3 +390,3 @@ slangpy/tests/slangpy_tests/test_return_types.py,sha256=MwnQy8791jYrelOFt8-9ImEL6H8zUUUfJRJ_Z2kXXz4,3731 | ||
| slangpy/types/__init__.py,sha256=fmSHX1CNHOdyiFQ6Q2QVsTHghJiADss_Wf9Y23SruNU,731 | ||
| slangpy/types/buffer.py,sha256=jpoIB5oL3jJE5MhtpkiMAdpF_AMR8e3bYEEPVjN8duM,16323 | ||
| slangpy/types/buffer.py,sha256=3cJBBREBwzLtiAIHaXn6woL3JCY-YAseInt3h2Npiu4,16584 | ||
| slangpy/types/callidarg.py,sha256=hIPIfFw17TbzL5rpmBcpZN5lsQK4ge66ZAVwO9wc-yA,2623 | ||
@@ -400,6 +402,6 @@ slangpy/types/diffpair.py,sha256=K9C4kplAdFJbjeDoe6f7JYbM43Cs3S60ygjGonrrVqg,1627 | ||
| slangpy/ui/__init__.pyi,sha256=6oKnW1qupqQCPSe82znQzssxR5mw6zAuM4mGr3CaSvQ,30512 | ||
| slangpy-0.34.0.dist-info/licenses/LICENSE,sha256=tzjZi3clQUtTTBGaxmEmh2uMvUCqiuSCENsZBxS4FnA,1490 | ||
| slangpy-0.34.0.dist-info/METADATA,sha256=-30QRMsOk7W4oCUsmiO_USKf5NCLRt0aFy6SGUYfPnU,4323 | ||
| slangpy-0.34.0.dist-info/WHEEL,sha256=QbwVjH-TlueSdsOKSic424jsJD30r52gknYt4wXVwF0,109 | ||
| slangpy-0.34.0.dist-info/top_level.txt,sha256=Wo7_Eny8d-MI3ZHT-XHGGMEKKIK38FlAmUCz8AffxA0,8 | ||
| slangpy-0.34.0.dist-info/RECORD,, | ||
| slangpy-0.35.0.dist-info/licenses/LICENSE,sha256=tzjZi3clQUtTTBGaxmEmh2uMvUCqiuSCENsZBxS4FnA,1490 | ||
| slangpy-0.35.0.dist-info/METADATA,sha256=Ue9H520aYJVwxkzlynefSeDPFT0wbm1O4nvFtvwhM94,4323 | ||
| slangpy-0.35.0.dist-info/WHEEL,sha256=QbwVjH-TlueSdsOKSic424jsJD30r52gknYt4wXVwF0,109 | ||
| slangpy-0.35.0.dist-info/top_level.txt,sha256=Wo7_Eny8d-MI3ZHT-XHGGMEKKIK38FlAmUCz8AffxA0,8 | ||
| slangpy-0.35.0.dist-info/RECORD,, |
+74
-14
@@ -135,4 +135,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| # Set call data mode based on device type | ||
| if build_info.module.device.info.type == DeviceType.cuda: | ||
| # Set call data mode based on device and pipeline type | ||
| if ( | ||
| build_info.module.device.info.type == DeviceType.cuda | ||
| and build_info.pipeline_type == PipelineType.compute | ||
| ): | ||
| self.call_data_mode = CallDataMode.entry_point | ||
@@ -312,5 +315,8 @@ else: | ||
| # Check if we've already built this module. | ||
| if hash in build_info.module.compute_pipeline_cache: | ||
| if hash in build_info.module.pipeline_cache: | ||
| # Get pipeline from cache if we have | ||
| self.compute_pipeline = build_info.module.compute_pipeline_cache[hash] | ||
| self.pipeline = build_info.module.pipeline_cache[hash] | ||
| # Get shader table from cache if the pipeline is a raytracing pipeline | ||
| if build_info.pipeline_type == PipelineType.ray_tracing: | ||
| self.shader_table = build_info.module.shader_table_cache[hash] | ||
| self.device = build_info.module.device | ||
@@ -325,15 +331,69 @@ self.log_debug(f" Found cached pipeline with hash {hash}") | ||
| module = session.load_module_from_source(hash, code) | ||
| ep = module.entry_point(f"compute_main", type_conformances) | ||
| opts = SlangLinkOptions() | ||
| opts.dump_intermediates = _DUMP_SLANG_INTERMEDIATES | ||
| opts.dump_intermediates_prefix = sanitized | ||
| program = session.link_program( | ||
| [module, build_info.module.device_module] + build_info.module.link, | ||
| [ep], | ||
| opts, | ||
| ) | ||
| self.compute_pipeline = device.create_compute_pipeline( | ||
| program, defer_target_compilation=True | ||
| ) | ||
| build_info.module.compute_pipeline_cache[hash] = self.compute_pipeline | ||
| if build_info.pipeline_type == PipelineType.compute: | ||
| # Create compute pipeline | ||
| ep = module.entry_point(f"compute_main", type_conformances) | ||
| program = session.link_program( | ||
| [module, build_info.module.device_module] + build_info.module.link, | ||
| [ep], | ||
| opts, | ||
| ) | ||
| self.pipeline = device.create_compute_pipeline( | ||
| program, defer_target_compilation=True | ||
| ) | ||
| build_info.module.pipeline_cache[hash] = self.pipeline | ||
| elif build_info.pipeline_type == PipelineType.ray_tracing: | ||
| # Create ray tracing pipeline | ||
| eps = [module.entry_point(f"raygen_main", type_conformances)] | ||
| hit_group_names: list[str] = [] | ||
| for hit_group in build_info.ray_tracing_hit_groups: | ||
| hit_group_names.append(hit_group.hit_group_name) | ||
| if hit_group.closest_hit_entry_point != "": | ||
| eps.append( | ||
| build_info.module.device_module.entry_point( | ||
| hit_group.closest_hit_entry_point | ||
| ) | ||
| ) | ||
| if hit_group.any_hit_entry_point != "": | ||
| eps.append( | ||
| build_info.module.device_module.entry_point( | ||
| hit_group.any_hit_entry_point | ||
| ) | ||
| ) | ||
| if hit_group.intersection_entry_point != "": | ||
| eps.append( | ||
| build_info.module.device_module.entry_point( | ||
| hit_group.intersection_entry_point | ||
| ) | ||
| ) | ||
| for miss_entry_point in build_info.ray_tracing_miss_entry_points: | ||
| eps.append(build_info.module.device_module.entry_point(miss_entry_point)) | ||
| program = session.link_program( | ||
| [module, build_info.module.device_module] + build_info.module.link, | ||
| eps, | ||
| opts, | ||
| ) | ||
| self.pipeline = device.create_ray_tracing_pipeline( | ||
| program, | ||
| hit_groups=build_info.ray_tracing_hit_groups, | ||
| max_recursion=build_info.ray_tracing_max_recursion, | ||
| max_ray_payload_size=build_info.ray_tracing_max_ray_payload_size, | ||
| max_attribute_size=build_info.ray_tracing_max_attribute_size, | ||
| flags=build_info.ray_tracing_flags, | ||
| defer_target_compilation=True, | ||
| ) | ||
| build_info.module.pipeline_cache[hash] = self.pipeline | ||
| self.shader_table = device.create_shader_table( | ||
| program, | ||
| ray_gen_entry_points=["raygen_main"], | ||
| miss_entry_points=build_info.ray_tracing_miss_entry_points, | ||
| hit_group_names=hit_group_names, | ||
| callable_entry_points=build_info.ray_tracing_callable_entry_points, | ||
| ) | ||
| build_info.module.shader_table_cache[hash] = self.shader_table | ||
| else: | ||
| raise RuntimeError("Unknown pipeline type") | ||
| self.device = device | ||
@@ -340,0 +400,0 @@ self.log_debug(f" Build succesful") |
@@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| from slangpy.core.native import AccessType, CallMode, CallDataMode, NativeMarshall | ||
| from slangpy.core.function import PipelineType | ||
| import slangpy.bindings.typeregistry as tr | ||
| import slangpy.reflection as slr | ||
| from slangpy import ModifierID, TypeReflection, DeviceType | ||
| from slangpy import ModifierID, TypeReflection | ||
| from slangpy.bindings.marshall import Marshall, BindContext, ReturnContext | ||
@@ -601,20 +602,32 @@ from slangpy.bindings.boundvariable import ( | ||
| # Generate the main function | ||
| cg.kernel.append_line('[shader("compute")]') | ||
| if call_group_size != 1: | ||
| cg.kernel.append_line(f"[numthreads({call_group_size}, 1, 1)]") | ||
| if build_info.pipeline_type == PipelineType.compute: | ||
| cg.kernel.append_line('[shader("compute")]') | ||
| if call_group_size != 1: | ||
| cg.kernel.append_line(f"[numthreads({call_group_size}, 1, 1)]") | ||
| else: | ||
| cg.kernel.append_line("[numthreads(32, 1, 1)]") | ||
| # Note: While flat_call_thread_id is 3-dimensional, we consider it "flat" and 1-dimensional because of the | ||
| # true call group shape of [x, 1, 1] and only use the first dimension for the call thread id. | ||
| if is_entry_point: | ||
| cg.kernel.append_line( | ||
| "void compute_main(int3 flat_call_thread_id: SV_DispatchThreadID, int3 flat_call_group_id: SV_GroupID, int flat_call_group_thread_id: SV_GroupIndex, uniform CallData call_data)" | ||
| ) | ||
| else: | ||
| cg.kernel.append_line( | ||
| "void compute_main(int3 flat_call_thread_id: SV_DispatchThreadID, int3 flat_call_group_id: SV_GroupID, int flat_call_group_thread_id: SV_GroupIndex)" | ||
| ) | ||
| elif build_info.pipeline_type == PipelineType.ray_tracing: | ||
| cg.kernel.append_line('[shader("raygen")]') | ||
| if is_entry_point: | ||
| cg.kernel.append_line("void raygen_main(uniform CallData call_data)") | ||
| else: | ||
| cg.kernel.append_line("void raygen_main()") | ||
| else: | ||
| cg.kernel.append_line("[numthreads(32, 1, 1)]") | ||
| raise RuntimeError(f"Unknown pipeline type: {build_info.pipeline_type}") | ||
| # Note: While flat_call_thread_id is 3-dimensional, we consider it "flat" and 1-dimensional because of the | ||
| # true call group shape of [x, 1, 1] and only use the first dimension for the call thread id. | ||
| if is_entry_point: | ||
| cg.kernel.append_line( | ||
| "void compute_main(int3 flat_call_thread_id: SV_DispatchThreadID, int3 flat_call_group_id: SV_GroupID, int flat_call_group_thread_id: SV_GroupIndex, CallData call_data)" | ||
| ) | ||
| else: | ||
| cg.kernel.append_line( | ||
| "void compute_main(int3 flat_call_thread_id: SV_DispatchThreadID, int3 flat_call_group_id: SV_GroupID, int flat_call_group_thread_id: SV_GroupIndex)" | ||
| ) | ||
| cg.kernel.begin_block() | ||
| if build_info.pipeline_type == PipelineType.ray_tracing: | ||
| cg.kernel.append_statement("int3 flat_call_thread_id = DispatchRaysIndex();") | ||
| cg.kernel.append_statement("if (any(flat_call_thread_id >= call_data._thread_count)) return") | ||
@@ -628,10 +641,18 @@ | ||
| if call_data_len > 0: | ||
| cg.kernel.append_line( | ||
| f""" | ||
| if (!init_thread_local_call_shape_info(flat_call_group_thread_id, | ||
| flat_call_group_id, flat_call_thread_id, call_data._grid_stride, | ||
| call_data._grid_dim, call_data._call_dim)) | ||
| return;""" | ||
| ) | ||
| if build_info.pipeline_type == PipelineType.compute: | ||
| cg.kernel.append_line( | ||
| f""" | ||
| if (!init_thread_local_call_shape_info(flat_call_group_thread_id, | ||
| flat_call_group_id, flat_call_thread_id, call_data._grid_stride, | ||
| call_data._grid_dim, call_data._call_dim)) | ||
| return;""" | ||
| ) | ||
| elif build_info.pipeline_type == PipelineType.ray_tracing: | ||
| cg.kernel.append_line( | ||
| f""" | ||
| if (!init_thread_local_call_shape_info(0, | ||
| uint3(0), flat_call_thread_id, call_data._grid_stride, | ||
| call_data._grid_dim, call_data._call_dim)) | ||
| return;""" | ||
| ) | ||
| context_args += ", CallShapeInfo::get_call_id().shape" | ||
@@ -638,0 +659,0 @@ |
@@ -12,3 +12,10 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| from slangpy import CommandEncoder, ShaderCursor, SlangLinkOptions, uint3, DeviceType | ||
| from slangpy import ( | ||
| CommandEncoder, | ||
| ComputePipeline, | ||
| ShaderCursor, | ||
| SlangLinkOptions, | ||
| uint3, | ||
| DeviceType, | ||
| ) | ||
| from slangpy.core.native import NativeCallRuntimeOptions | ||
@@ -178,5 +185,8 @@ from slangpy.bindings.marshall import BindContext | ||
| # Check if we've already built this module. | ||
| if hash in build_info.module.compute_pipeline_cache: | ||
| if hash in build_info.module.pipeline_cache: | ||
| # Get pipeline from cache if we have | ||
| self.compute_pipeline = build_info.module.compute_pipeline_cache[hash] | ||
| pipeline = build_info.module.pipeline_cache[hash] | ||
| if not isinstance(pipeline, ComputePipeline): | ||
| raise RuntimeError("Pipeline cache entry is not a ComputePipeline") | ||
| self.compute_pipeline = pipeline | ||
| self.device = build_info.module.device | ||
@@ -183,0 +193,0 @@ else: |
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, Union, cast | ||
| from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, Union, cast, Sequence | ||
| from enum import Enum | ||
@@ -13,3 +14,12 @@ from slangpy.core.native import ( | ||
| from slangpy.reflection import SlangFunction, SlangType | ||
| from slangpy import CommandEncoder, TypeConformance, uint3, Logger, NativeHandle, NativeHandleType | ||
| from slangpy import ( | ||
| CommandEncoder, | ||
| TypeConformance, | ||
| uint3, | ||
| Logger, | ||
| NativeHandle, | ||
| NativeHandleType, | ||
| RayTracingPipelineFlags, | ||
| HitGroupDesc, | ||
| ) | ||
| from slangpy.slangpy import Shape | ||
@@ -22,2 +32,3 @@ from slangpy.bindings.typeregistry import PYTHON_SIGNATURES | ||
| from slangpy.core.struct import Struct | ||
| from slangpy import HitGroupDescParam | ||
@@ -44,2 +55,7 @@ ENABLE_CALLDATA_CACHE = True | ||
| class PipelineType(Enum): | ||
| compute = 0 | ||
| ray_tracing = 1 | ||
| class FunctionBuildInfo: | ||
@@ -66,2 +82,10 @@ def __init__(self) -> None: | ||
| self.call_group_shape: Optional[Shape] = None | ||
| self.pipeline_type: PipelineType = PipelineType.compute | ||
| self.ray_tracing_hit_groups: list[HitGroupDesc] = [] | ||
| self.ray_tracing_miss_entry_points: list[str] = [] | ||
| self.ray_tracing_callable_entry_points: list[str] = [] | ||
| self.ray_tracing_max_recursion: int = 0 | ||
| self.ray_tracing_max_ray_payload_size: int = 0 | ||
| self.ray_tracing_max_attribute_size: int = 8 | ||
| self.ray_tracing_flags: RayTracingPipelineFlags = RayTracingPipelineFlags.none | ||
@@ -160,2 +184,26 @@ | ||
| def ray_tracing( | ||
| self, | ||
| hit_groups: Sequence["HitGroupDescParam"], | ||
| miss_entry_points: Sequence[str] = [], | ||
| callable_entry_points: Sequence[str] = [], | ||
| max_recursion: int = 1, | ||
| max_ray_payload_size: int = 32, | ||
| max_attribute_size: int = 8, | ||
| flags: RayTracingPipelineFlags = RayTracingPipelineFlags.none, | ||
| ): | ||
| """ | ||
| Specify the ray tracing pipeline configuration. | ||
| """ | ||
| return FunctionNodeRayTracing( | ||
| self, | ||
| hit_groups, | ||
| miss_entry_points, | ||
| callable_entry_points, | ||
| max_recursion, | ||
| max_ray_payload_size, | ||
| max_attribute_size, | ||
| flags, | ||
| ) | ||
| @property | ||
@@ -463,2 +511,41 @@ def bwds(self): | ||
| class FunctionNodeRayTracing(FunctionNode): | ||
| def __init__( | ||
| self, | ||
| parent: NativeFunctionNode, | ||
| hit_groups: Sequence["HitGroupDescParam"], | ||
| miss_entry_points: Sequence[str], | ||
| callable_entry_points: Sequence[str], | ||
| max_recursion: int, | ||
| max_ray_payload_size: int, | ||
| max_attribute_size: int, | ||
| flags: RayTracingPipelineFlags, | ||
| ): | ||
| super().__init__( | ||
| parent, | ||
| FunctionNodeType.ray_tracing, | ||
| { | ||
| "hit_groups": [HitGroupDesc(hit_group) for hit_group in hit_groups], # type: ignore | ||
| "miss_entry_points": list(miss_entry_points), | ||
| "callable_entry_points": list(callable_entry_points), | ||
| "max_recursion": max_recursion, | ||
| "max_ray_payload_size": max_ray_payload_size, | ||
| "max_attribute_size": max_attribute_size, | ||
| "flags": flags, | ||
| }, | ||
| ) | ||
| self.slangpy_signature = f"({hit_groups}, {miss_entry_points}, {callable_entry_points}, {max_recursion}, {max_ray_payload_size}, {max_attribute_size}, {flags})" | ||
| def _populate_build_info(self, info: FunctionBuildInfo): | ||
| d = cast(dict[str, Any], self._native_data) | ||
| info.pipeline_type = PipelineType.ray_tracing | ||
| info.ray_tracing_hit_groups = d["hit_groups"] | ||
| info.ray_tracing_miss_entry_points = d["miss_entry_points"] | ||
| info.ray_tracing_callable_entry_points = d["callable_entry_points"] | ||
| info.ray_tracing_max_recursion = d["max_recursion"] | ||
| info.ray_tracing_max_ray_payload_size = d["max_ray_payload_size"] | ||
| info.ray_tracing_max_attribute_size = d["max_attribute_size"] | ||
| info.ray_tracing_flags = d["flags"] | ||
| class FunctionNodeBwds(FunctionNode): | ||
@@ -465,0 +552,0 @@ def __init__(self, parent: NativeFunctionNode) -> None: |
@@ -7,3 +7,3 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| from slangpy import ComputePipeline, SlangModule, Device, Logger | ||
| from slangpy import Pipeline, ShaderTable, SlangModule, Device, Logger | ||
| from slangpy.core.native import NativeCallDataCache | ||
@@ -77,3 +77,4 @@ from slangpy.reflection import SlangProgramLayout | ||
| self.dispatch_data_cache: dict[str, "DispatchData"] = {} | ||
| self.compute_pipeline_cache: dict[str, ComputePipeline] = {} | ||
| self.pipeline_cache: dict[str, Pipeline] = {} | ||
| self.shader_table_cache: dict[str, ShaderTable] = {} | ||
| self.logger: Optional[Logger] = None | ||
@@ -215,3 +216,4 @@ | ||
| self.dispatch_data_cache = {} | ||
| self.compute_pipeline_cache = {} | ||
| self.pipeline_cache = {} | ||
| self.shader_table_cache = {} | ||
| self._attr_cache = {} | ||
@@ -218,0 +220,0 @@ |
@@ -15,16 +15,3 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| class SGL_API EOFException : public std::runtime_error { | ||
| public: | ||
| EOFException(const std::string& what, size_t gcount) | ||
| : std::runtime_error(what) | ||
| , m_gcount(gcount) | ||
| { | ||
| } | ||
| size_t gcount() const { return m_gcount; } | ||
| private: | ||
| size_t m_gcount; | ||
| }; | ||
| class SGL_API FileStream : public Stream { | ||
@@ -31,0 +18,0 @@ SGL_OBJECT(FileStream) |
@@ -10,2 +10,16 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| class SGL_API EOFException : public std::runtime_error { | ||
| public: | ||
| EOFException(const std::string& what, size_t gcount) | ||
| : std::runtime_error(what) | ||
| , m_gcount(gcount) | ||
| { | ||
| } | ||
| size_t gcount() const { return m_gcount; } | ||
| private: | ||
| size_t m_gcount; | ||
| }; | ||
| /** | ||
@@ -12,0 +26,0 @@ * \brief Base class for all stream objects. |
@@ -714,2 +714,13 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| template<floating_point T> | ||
| [[nodiscard]] inline matrix<T, 4, 4> matrix_4x4_from_3x4(const matrix<T, 3, 4>& m) | ||
| { | ||
| matrix<T, 4, 4> result; | ||
| for (int r = 0; r < 3; ++r) { | ||
| result.set_row(r, m.get_row(r)); | ||
| } | ||
| result[3][3] = T(1); | ||
| return result; | ||
| } | ||
| template<typename T, int R, int C> | ||
@@ -716,0 +727,0 @@ [[nodiscard]] std::string to_string(const matrix<T, R, C>& m) |
@@ -8,3 +8,3 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| #define SGL_VERSION_MAJOR 0 | ||
| #define SGL_VERSION_MINOR 34 | ||
| #define SGL_VERSION_MINOR 35 | ||
| #define SGL_VERSION_PATCH 0 | ||
@@ -11,0 +11,0 @@ |
@@ -336,9 +336,16 @@ from collections.abc import Mapping, Sequence | ||
| @property | ||
| def compute_pipeline(self) -> slangpy.ComputePipeline: | ||
| def pipeline(self) -> slangpy.Pipeline: | ||
| """N/A""" | ||
| @compute_pipeline.setter | ||
| def compute_pipeline(self, arg: slangpy.ComputePipeline, /) -> None: ... | ||
| @pipeline.setter | ||
| def pipeline(self, arg: slangpy.Pipeline, /) -> None: ... | ||
| @property | ||
| def shader_table(self) -> slangpy.ShaderTable: | ||
| """N/A""" | ||
| @shader_table.setter | ||
| def shader_table(self, arg: slangpy.ShaderTable, /) -> None: ... | ||
| @property | ||
| def call_dimensionality(self) -> int: | ||
@@ -712,2 +719,4 @@ """N/A""" | ||
| ray_tracing = 5 | ||
| class NativeFunctionNode(NativeObject): | ||
@@ -714,0 +723,0 @@ def __init__(self, parent: NativeFunctionNode | None, type: FunctionNodeType, data: object | None) -> None: |
@@ -58,3 +58,3 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| assert float_float_cd != mapped_float_float_cd | ||
| assert float_float_cd.compute_pipeline == mapped_float_float_cd.compute_pipeline | ||
| assert float_float_cd.pipeline == mapped_float_float_cd.pipeline | ||
@@ -61,0 +61,0 @@ |
@@ -73,2 +73,8 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| def _hot_reload_lookup_module(device: Device): | ||
| if device in global_lookup_modules: | ||
| dummy_module = device.load_module_from_source("slangpy_layout", 'import "slangpy";') | ||
| global_lookup_modules[device].on_hot_reload(dummy_module.layout) | ||
| def get_lookup_module(device: Device) -> SlangProgramLayout: | ||
@@ -78,3 +84,3 @@ if device not in global_lookup_modules: | ||
| device.register_device_close_callback(_on_device_close) | ||
| device.register_shader_hot_reload_callback(lambda _: _load_lookup_module(device)) | ||
| device.register_shader_hot_reload_callback(lambda _: _hot_reload_lookup_module(device)) | ||
@@ -81,0 +87,0 @@ return global_lookup_modules[device] |
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.