From 63934a9a552cbb255845190a079ebd90f48892a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-Fran=C3=A7ois=20Reboud?= Date: Fri, 8 Dec 2023 10:00:55 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix:=20run=20on=20Apple=20Silico?= =?UTF-8?q?n=20(#110)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 1 + Sources/GrAIdient/Metal/MetalKernel.swift | 11 +++- Sources/GrAIdient/Utils/Image.swift | 72 +++++++++-------------- Tests/GrAIExamples/Base/Utils.swift | 2 +- Tests/GrAITorchTests/Base/Utils.swift | 2 +- 5 files changed, 41 insertions(+), 47 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8aed98a3..ca6b982a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file. ## [unreleased] +🐛 **fix:** run on Apple Silicon ([110](https://github.com/owkin/GrAIdient/pull/110))\ ⚙️ **core:** initForward,Backward model API ([109](https://github.com/owkin/GrAIdient/pull/109))\ 🪜 **layer_1d:** Dropout1D ([#108](https://github.com/owkin/GrAIdient/pull/108))\ 🪜 **feat:** VQGrad, VQGradSeq ([#107](https://github.com/owkin/GrAIdient/pull/107)) diff --git a/Sources/GrAIdient/Metal/MetalKernel.swift b/Sources/GrAIdient/Metal/MetalKernel.swift index 7228653c..5425b42c 100644 --- a/Sources/GrAIdient/Metal/MetalKernel.swift +++ b/Sources/GrAIdient/Metal/MetalKernel.swift @@ -969,7 +969,16 @@ public class MetalCommand public func setBytes(_ data: [T], atIndex index: Int) { let byteLength = data.count * MemoryLayout.size - _encoder.setBytes(data, length: byteLength, index: index) + data.withUnsafeBufferPointer + { + dataPtr in + + _encoder.setBytes( + UnsafeRawPointer(dataPtr.baseAddress)!, + length: byteLength, + index: index + ) + } } /// diff --git a/Sources/GrAIdient/Utils/Image.swift b/Sources/GrAIdient/Utils/Image.swift index 2450a321..9c24c81d 100644 --- a/Sources/GrAIdient/Utils/Image.swift +++ b/Sources/GrAIdient/Utils/Image.swift @@ -6,7 +6,7 @@ // import Foundation -import Cocoa +import AppKit /// Error occuring when processing images. public enum ImageError: Error @@ -107,42 +107,14 @@ public class Image let bufferPtr = metalBuffer.download() let nbImages = metalBuffer.nbElems / (width * height * 3) - var output = [[UInt8]]() - for elem in 0.. 255.0 - { - val = 255 - } - else - { - val = UInt8(valTmp) - } - - gridPtr[3 * offsetSet + depth] = val - }} - output.append(grid) + images.append([Float]( + bufferPtr[i * 3 * height * width..<(i+1) * 3 * height * width] + )) } - return output + return toRGB(toPixel(images), width: width, height: height) } /// @@ -157,7 +129,8 @@ public class Image var output = [[UInt8]]() for elem in 0.. [UInt8] { - if let imageData = tiffRepresentation, - let imageRep = NSBitmapImageRep(data: imageData), - let dataPtr = imageRep.bitmapData + if let pixelData = (cgImage( + forProposedRect: nil, context: nil, hints: nil)!).dataProvider?.data { - let bufferPtr = UnsafeBufferPointer( - start: dataPtr, - count: Int(3 * size.height * size.width) - ) - return [UInt8](bufferPtr) + let data: UnsafePointer = CFDataGetBytePtr(pixelData) + + var pixels = [UInt8]() + for i in 0..